mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 07:28:06 -05:00
Compare commits
85 Commits
v5.9.0
...
v5.10.0dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d06cc71cd9 | ||
|
|
ecbc1cf85d | ||
|
|
dcad306aee | ||
|
|
bea9d037bc | ||
|
|
eed4260975 | ||
|
|
6a79b1c64c | ||
|
|
4cdfe3e30d | ||
|
|
df294db236 | ||
|
|
52e247cfe0 | ||
|
|
0a13640bf3 | ||
|
|
643b71f56c | ||
|
|
b745411866 | ||
|
|
c3ffb0feed | ||
|
|
f53ff5fa3c | ||
|
|
b2337b56bd | ||
|
|
fb777b4502 | ||
|
|
4c12f5a011 | ||
|
|
d4655ea21a | ||
|
|
752b62d0b5 | ||
|
|
9a0efb308d | ||
|
|
b4c276b50f | ||
|
|
db03c196a1 | ||
|
|
6bc36b697d | ||
|
|
b7d71d3028 | ||
|
|
fa1ebd9d2f | ||
|
|
eed5d02069 | ||
|
|
3650d91045 | ||
|
|
6c7d08cacb | ||
|
|
bb1c40f222 | ||
|
|
bfb117d0e0 | ||
|
|
b31c1022c3 | ||
|
|
a5851ca31c | ||
|
|
77bf5c15bb | ||
|
|
d26b7a1a12 | ||
|
|
595133463e | ||
|
|
6155f9ff9e | ||
|
|
7be87c8048 | ||
|
|
9868c3bfe3 | ||
|
|
8b299d0bac | ||
|
|
a44bfb4658 | ||
|
|
96fb5f6881 | ||
|
|
4109ea5324 | ||
|
|
f6c2ee5040 | ||
|
|
965753bf8b | ||
|
|
40c53ab95c | ||
|
|
aaa6211625 | ||
|
|
f6d770eac9 | ||
|
|
47cb61cd62 | ||
|
|
b0fdc8ae1c | ||
|
|
ed9b30efda | ||
|
|
168e5eeff0 | ||
|
|
7acaa86bdf | ||
|
|
96c0393fe7 | ||
|
|
403f795c5e | ||
|
|
c0f88a083e | ||
|
|
542b182899 | ||
|
|
3f58c68c09 | ||
|
|
e50c7e5947 | ||
|
|
4a83700fe4 | ||
|
|
c9992914d6 | ||
|
|
c25f6d1f84 | ||
|
|
a53e1ccf08 | ||
|
|
1af9930951 | ||
|
|
c276c1cbee | ||
|
|
c619348f29 | ||
|
|
c6f96613fc | ||
|
|
258bf736da | ||
|
|
0d75c99476 | ||
|
|
323d409fb6 | ||
|
|
f251722f56 | ||
|
|
c9dc27afbb | ||
|
|
efd14ec0e4 | ||
|
|
21ee2b6251 | ||
|
|
82dd2d508f | ||
|
|
3f12a43e75 | ||
|
|
5a59f6e3b8 | ||
|
|
60b5aef16a | ||
|
|
0e8b5484d5 | ||
|
|
454506c83e | ||
|
|
8f6ab67376 | ||
|
|
5afcc7778f | ||
|
|
325e07d330 | ||
|
|
a016bdc159 | ||
|
|
a14f0b2864 | ||
|
|
721483318a |
@@ -1,9 +1,11 @@
|
||||
*
|
||||
!invokeai
|
||||
!pyproject.toml
|
||||
!uv.lock
|
||||
!docker/docker-entrypoint.sh
|
||||
!LICENSE
|
||||
|
||||
**/dist
|
||||
**/node_modules
|
||||
**/__pycache__
|
||||
**/*.egg-info
|
||||
**/*.egg-info
|
||||
|
||||
2
.github/workflows/build-container.yml
vendored
2
.github/workflows/build-container.yml
vendored
@@ -97,6 +97,8 @@ jobs:
|
||||
context: .
|
||||
file: docker/Dockerfile
|
||||
platforms: ${{ env.PLATFORMS }}
|
||||
build-args: |
|
||||
GPU_DRIVER=${{ matrix.gpu-driver }}
|
||||
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
2
.github/workflows/build-installer.yml
vendored
2
.github/workflows/build-installer.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: setup python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: '3.12'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
|
||||
21
.github/workflows/python-checks.yml
vendored
21
.github/workflows/python-checks.yml
vendored
@@ -34,6 +34,9 @@ on:
|
||||
|
||||
jobs:
|
||||
python-checks:
|
||||
env:
|
||||
# uv requires a venv by default - but for this, we can simply use the system python
|
||||
UV_SYSTEM_PYTHON: 1
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
steps:
|
||||
@@ -57,25 +60,19 @@ jobs:
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup python
|
||||
- name: setup uv
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install ruff
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: pip install ruff==0.9.9
|
||||
shell: bash
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
|
||||
- name: ruff check
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: ruff check --output-format=github .
|
||||
run: uv tool run ruff@0.11.2 check --output-format=github .
|
||||
shell: bash
|
||||
|
||||
- name: ruff format
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: ruff format --check .
|
||||
run: uv tool run ruff@0.11.2 format --check .
|
||||
shell: bash
|
||||
|
||||
30
.github/workflows/python-tests.yml
vendored
30
.github/workflows/python-tests.yml
vendored
@@ -39,24 +39,15 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- '3.10'
|
||||
- '3.11'
|
||||
- '3.12'
|
||||
platform:
|
||||
- linux-cuda-11_7
|
||||
- linux-rocm-5_2
|
||||
- linux-cpu
|
||||
- macos-default
|
||||
- windows-cpu
|
||||
include:
|
||||
- platform: linux-cuda-11_7
|
||||
os: ubuntu-22.04
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-rocm-5_2
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-cpu
|
||||
os: ubuntu-22.04
|
||||
os: ubuntu-24.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: macos-default
|
||||
@@ -70,6 +61,8 @@ jobs:
|
||||
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
|
||||
env:
|
||||
PIP_USE_PEP517: '1'
|
||||
UV_SYSTEM_PYTHON: 1
|
||||
|
||||
steps:
|
||||
- name: checkout
|
||||
# https://github.com/nschloe/action-cached-lfs-checkout
|
||||
@@ -92,20 +85,25 @@ jobs:
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup uv
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
env:
|
||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||
run: >
|
||||
pip3 install --editable=".[test]"
|
||||
UV_INDEX: ${{ matrix.extra-index-url }}
|
||||
run: uv pip install --editable ".[test]"
|
||||
|
||||
- name: run pytest
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
|
||||
20
.github/workflows/typegen-checks.yml
vendored
20
.github/workflows/typegen-checks.yml
vendored
@@ -54,17 +54,25 @@ jobs:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
|
||||
- name: setup uv
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
python-version: '3.11'
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
python-version: '3.11'
|
||||
|
||||
- name: install python dependencies
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: pip3 install --use-pep517 --editable="."
|
||||
env:
|
||||
UV_INDEX: ${{ matrix.extra-index-url }}
|
||||
run: uv pip install --editable .
|
||||
|
||||
- name: install frontend dependencies
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
@@ -77,7 +85,7 @@ jobs:
|
||||
|
||||
- name: generate schema
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: make frontend-typegen
|
||||
run: cd invokeai/frontend/web && uv run ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
||||
shell: bash
|
||||
|
||||
- name: compare files
|
||||
|
||||
@@ -1,77 +1,6 @@
|
||||
# syntax=docker/dockerfile:1.4
|
||||
|
||||
## Builder stage
|
||||
|
||||
FROM library/ubuntu:24.04 AS builder
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt update && apt-get install -y \
|
||||
build-essential \
|
||||
git
|
||||
|
||||
# Install `uv` for package management
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
|
||||
|
||||
ENV VIRTUAL_ENV=/opt/venv
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV PYTHON_VERSION=3.11
|
||||
ENV UV_PYTHON=3.11
|
||||
ENV UV_COMPILE_BYTECODE=1
|
||||
ENV UV_LINK_MODE=copy
|
||||
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
|
||||
ENV UV_INDEX="https://download.pytorch.org/whl/cu124"
|
||||
|
||||
ARG GPU_DRIVER=cuda
|
||||
# unused but available
|
||||
ARG BUILDPLATFORM
|
||||
|
||||
# Switch to the `ubuntu` user to work around dependency issues with uv-installed python
|
||||
RUN mkdir -p ${VIRTUAL_ENV} && \
|
||||
mkdir -p ${INVOKEAI_SRC} && \
|
||||
chmod -R a+w /opt && \
|
||||
mkdir ~ubuntu/.cache && chown ubuntu: ~ubuntu/.cache
|
||||
USER ubuntu
|
||||
|
||||
# Install python
|
||||
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
|
||||
uv python install ${PYTHON_VERSION}
|
||||
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
|
||||
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
|
||||
# bind-mount instead of copy to defer adding sources to the image until next layer.
|
||||
#
|
||||
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
|
||||
# x86_64/CUDA is the default
|
||||
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=invokeai/version,target=invokeai/version \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
|
||||
UV_INDEX="https://download.pytorch.org/whl/cpu"; \
|
||||
elif [ "$GPU_DRIVER" = "rocm" ]; then \
|
||||
UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
|
||||
fi && \
|
||||
uv sync --no-install-project
|
||||
|
||||
# Now that the bulk of the dependencies have been installed, copy in the project files that change more frequently.
|
||||
COPY invokeai invokeai
|
||||
COPY pyproject.toml .
|
||||
|
||||
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
|
||||
UV_INDEX="https://download.pytorch.org/whl/cpu"; \
|
||||
elif [ "$GPU_DRIVER" = "rocm" ]; then \
|
||||
UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
|
||||
fi && \
|
||||
uv sync
|
||||
|
||||
|
||||
#### Build the Web UI ------------------------------------
|
||||
#### Web UI ------------------------------------
|
||||
|
||||
FROM docker.io/node:22-slim AS web-builder
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
@@ -85,69 +14,89 @@ RUN --mount=type=cache,target=/pnpm/store \
|
||||
pnpm install --frozen-lockfile
|
||||
RUN npx vite build
|
||||
|
||||
#### Runtime stage ---------------------------------------
|
||||
## Backend ---------------------------------------
|
||||
|
||||
FROM library/ubuntu:24.04 AS runtime
|
||||
FROM library/ubuntu:24.04
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
--mount=type=cache,target=/var/lib/apt \
|
||||
apt update && apt install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
git \
|
||||
gosu \
|
||||
libglib2.0-0 \
|
||||
libgl1 \
|
||||
libglx-mesa0 \
|
||||
build-essential \
|
||||
libopencv-dev \
|
||||
libstdc++-10-dev
|
||||
|
||||
RUN apt update && apt install -y --no-install-recommends \
|
||||
git \
|
||||
curl \
|
||||
vim \
|
||||
tmux \
|
||||
ncdu \
|
||||
iotop \
|
||||
bzip2 \
|
||||
gosu \
|
||||
magic-wormhole \
|
||||
libglib2.0-0 \
|
||||
libgl1 \
|
||||
libglx-mesa0 \
|
||||
build-essential \
|
||||
libopencv-dev \
|
||||
libstdc++-10-dev &&\
|
||||
apt-get clean && apt-get autoclean
|
||||
ENV \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
VIRTUAL_ENV=/opt/venv \
|
||||
INVOKEAI_SRC=/opt/invokeai \
|
||||
PYTHON_VERSION=3.12 \
|
||||
UV_PYTHON=3.12 \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_MANAGED_PYTHON=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||
UV_INDEX="https://download.pytorch.org/whl/cu124" \
|
||||
INVOKEAI_ROOT=/invokeai \
|
||||
INVOKEAI_HOST=0.0.0.0 \
|
||||
INVOKEAI_PORT=9090 \
|
||||
PATH="/opt/venv/bin:$PATH" \
|
||||
CONTAINER_UID=${CONTAINER_UID:-1000} \
|
||||
CONTAINER_GID=${CONTAINER_GID:-1000}
|
||||
|
||||
ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV VIRTUAL_ENV=/opt/venv
|
||||
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
|
||||
ENV PYTHON_VERSION=3.11
|
||||
ENV INVOKEAI_ROOT=/invokeai
|
||||
ENV INVOKEAI_HOST=0.0.0.0
|
||||
ENV INVOKEAI_PORT=9090
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
|
||||
ENV CONTAINER_UID=${CONTAINER_UID:-1000}
|
||||
ENV CONTAINER_GID=${CONTAINER_GID:-1000}
|
||||
ARG GPU_DRIVER=cuda
|
||||
|
||||
# Install `uv` for package management
|
||||
# and install python for the ubuntu user (expected to exist on ubuntu >=24.x)
|
||||
# this is too tiny to optimize with multi-stage builds, but maybe we'll come back to it
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
|
||||
USER ubuntu
|
||||
RUN uv python install ${PYTHON_VERSION}
|
||||
USER root
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.6.9 /uv /uvx /bin/
|
||||
|
||||
# --link requires buldkit w/ dockerfile syntax 1.4
|
||||
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
|
||||
COPY --link --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
|
||||
|
||||
# Link amdgpu.ids for ROCm builds
|
||||
# contributed by https://github.com/Rubonnek
|
||||
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
|
||||
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
|
||||
# Install python & allow non-root user to use it by traversing the /root dir without read permissions
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv python install ${PYTHON_VERSION} && \
|
||||
# chmod --recursive a+rX /root/.local/share/uv/python
|
||||
chmod 711 /root
|
||||
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
|
||||
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
|
||||
# bind-mount instead of copy to defer adding sources to the image until next layer.
|
||||
#
|
||||
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
|
||||
# x86_64/CUDA is the default
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
|
||||
--mount=type=bind,source=invokeai/version,target=invokeai/version \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
|
||||
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
|
||||
fi && \
|
||||
uv sync --frozen
|
||||
|
||||
# build patchmatch
|
||||
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
|
||||
RUN python -c "from patchmatch import patch_match"
|
||||
|
||||
# Link amdgpu.ids for ROCm builds
|
||||
# contributed by https://github.com/Rubonnek
|
||||
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
|
||||
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
|
||||
|
||||
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
|
||||
|
||||
COPY docker/docker-entrypoint.sh ./
|
||||
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
|
||||
CMD ["invokeai-web"]
|
||||
|
||||
# --link requires buldkit w/ dockerfile syntax 1.4, does not work with podman
|
||||
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
|
||||
|
||||
# add sources last to minimize image changes on code changes
|
||||
COPY invokeai ${INVOKEAI_SRC}/invokeai
|
||||
@@ -41,7 +41,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
|
||||
With the modifications made, the install command should look something like this:
|
||||
|
||||
```sh
|
||||
uv pip install -e ".[dev,test,docs,xformers]" --python 3.11 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
|
||||
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
|
||||
```
|
||||
|
||||
6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
|
||||
|
||||
@@ -43,10 +43,10 @@ The following commands vary depending on the version of Invoke being installed a
|
||||
3. Create a virtual environment in that directory:
|
||||
|
||||
```sh
|
||||
uv venv --relocatable --prompt invoke --python 3.11 --python-preference only-managed .venv
|
||||
uv venv --relocatable --prompt invoke --python 3.12 --python-preference only-managed .venv
|
||||
```
|
||||
|
||||
This command creates a portable virtual environment at `.venv` complete with a portable python 3.11. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
|
||||
This command creates a portable virtual environment at `.venv` complete with a portable python 3.12. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
|
||||
|
||||
4. Activate the virtual environment:
|
||||
|
||||
@@ -88,13 +88,13 @@ The following commands vary depending on the version of Invoke being installed a
|
||||
8. Install the `invokeai` package. Substitute the package specifier and version.
|
||||
|
||||
```sh
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --force-reinstall
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --force-reinstall
|
||||
```
|
||||
|
||||
If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:
|
||||
|
||||
```sh
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
|
||||
```
|
||||
|
||||
9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:
|
||||
|
||||
@@ -41,7 +41,7 @@ The requirements below are rough guidelines for best performance. GPUs with less
|
||||
|
||||
You don't need to do this if you are installing with the [Invoke Launcher](./quick_start.md).
|
||||
|
||||
Invoke requires python 3.10 or 3.11. If you don't already have one of these versions installed, we suggest installing 3.11, as it will be supported for longer.
|
||||
Invoke requires python 3.10 through 3.12. If you don't already have one of these versions installed, we suggest installing 3.12, as it will be supported for longer.
|
||||
|
||||
Check that your system has an up-to-date Python installed by running `python3 --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).
|
||||
|
||||
@@ -49,19 +49,19 @@ Check that your system has an up-to-date Python installed by running `python3 --
|
||||
|
||||
=== "Windows"
|
||||
|
||||
- Install python 3.11 with [an official installer].
|
||||
- Install python with [an official installer].
|
||||
- The installer includes an option to add python to your PATH. Be sure to enable this. If you missed it, re-run the installer, choose to modify an existing installation, and tick that checkbox.
|
||||
- You may need to install [Microsoft Visual C++ Redistributable].
|
||||
|
||||
=== "macOS"
|
||||
|
||||
- Install python 3.11 with [an official installer].
|
||||
- Install python with [an official installer].
|
||||
- If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
|
||||
- If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.
|
||||
|
||||
=== "Linux"
|
||||
|
||||
- Installing python varies depending on your system. On Ubuntu, you can use the [deadsnakes PPA](https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa).
|
||||
- Installing python varies depending on your system. We recommend [using `uv` to manage your python installation](https://docs.astral.sh/uv/concepts/python-versions/#installing-a-python-version).
|
||||
- You'll need to install `libglib2.0-0` and `libgl1-mesa-glx` for OpenCV to work. For example, on a Debian system: `sudo apt update && sudo apt install -y libglib2.0-0 libgl1-mesa-glx`
|
||||
|
||||
## Drivers
|
||||
|
||||
@@ -37,7 +37,13 @@ from invokeai.app.services.style_preset_records.style_preset_records_sqlite impo
|
||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
FLUXConditioningInfo,
|
||||
SD3ConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
@@ -101,10 +107,25 @@ class ApiDependencies:
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
tensors = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True)
|
||||
ObjectSerializerDisk[torch.Tensor](
|
||||
output_folder / "tensors",
|
||||
safe_globals=[torch.Tensor],
|
||||
ephemeral=True,
|
||||
),
|
||||
max_cache_size=0,
|
||||
)
|
||||
conditioning = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
ObjectSerializerDisk[ConditioningFieldData](
|
||||
output_folder / "conditioning",
|
||||
safe_globals=[
|
||||
ConditioningFieldData,
|
||||
BasicConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
FLUXConditioningInfo,
|
||||
SD3ConditioningInfo,
|
||||
],
|
||||
ephemeral=True,
|
||||
),
|
||||
)
|
||||
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
|
||||
@@ -96,6 +96,22 @@ async def upload_image(
|
||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||
|
||||
|
||||
class ImageUploadEntry(BaseModel):
|
||||
image_dto: ImageDTO = Body(description="The image DTO")
|
||||
presigned_url: str = Body(description="The URL to get the presigned URL for the image upload")
|
||||
|
||||
|
||||
@images_router.post("/", operation_id="create_image_upload_entry")
|
||||
async def create_image_upload_entry(
|
||||
width: int = Body(description="The width of the image"),
|
||||
height: int = Body(description="The height of the image"),
|
||||
board_id: Optional[str] = Body(default=None, description="The board to add this image to, if any"),
|
||||
) -> ImageUploadEntry:
|
||||
"""Uploads an image from a URL, not implemented"""
|
||||
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
|
||||
@@ -8,6 +8,7 @@ import sys
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -27,7 +28,6 @@ import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import TypeAliasType
|
||||
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldKind,
|
||||
@@ -100,37 +100,6 @@ class BaseInvocationOutput(BaseModel):
|
||||
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||
"""Registers an invocation output."""
|
||||
cls._output_classes.add(output)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
||||
"""Gets all invocation outputs."""
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocationOutput = TypeAliasType(
|
||||
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation output types."""
|
||||
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
|
||||
@@ -173,76 +142,16 @@ class BaseInvocation(ABC, BaseModel):
|
||||
All invocations must use the `@invocation` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
"""Gets the invocation's type, as provided by the `@invocation` decorator."""
|
||||
return cls.model_fields["type"].default
|
||||
|
||||
@classmethod
|
||||
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
||||
"""Registers an invocation."""
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocation = TypeAliasType(
|
||||
"AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocation)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
def invalidate_typeadapter(cls) -> None:
|
||||
"""Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or
|
||||
denylist is changed, this should be called to ensure the typeadapter is updated and validation respects
|
||||
the updated allowlist and denylist."""
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||
"""Gets all invocations, respecting the allowlist and denylist."""
|
||||
app_config = get_config()
|
||||
allowed_invocations: set[BaseInvocation] = set()
|
||||
for sc in cls._invocation_classes:
|
||||
invocation_type = sc.get_type()
|
||||
is_in_allowlist = (
|
||||
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
||||
)
|
||||
is_in_denylist = (
|
||||
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
|
||||
)
|
||||
if is_in_allowlist and not is_in_denylist:
|
||||
allowed_invocations.add(sc)
|
||||
return allowed_invocations
|
||||
|
||||
@classmethod
|
||||
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
|
||||
"""Gets a map of all invocation types to their invocation classes."""
|
||||
return {i.get_type(): i for i in BaseInvocation.get_invocations()}
|
||||
|
||||
@classmethod
|
||||
def get_invocation_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation types."""
|
||||
return (i.get_type() for i in BaseInvocation.get_invocations())
|
||||
|
||||
@classmethod
|
||||
def get_output_annotation(cls) -> BaseInvocationOutput:
|
||||
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
@classmethod
|
||||
def get_invocation_for_type(cls, invocation_type: str) -> BaseInvocation | None:
|
||||
"""Gets the invocation class for a given invocation type."""
|
||||
return cls.get_invocations_map().get(invocation_type)
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||
@@ -340,6 +249,105 @@ class BaseInvocation(ABC, BaseModel):
|
||||
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
|
||||
class InvocationRegistry:
|
||||
_invocation_classes: ClassVar[set[type[BaseInvocation]]] = set()
|
||||
_output_classes: ClassVar[set[type[BaseInvocationOutput]]] = set()
|
||||
|
||||
@classmethod
|
||||
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
|
||||
"""Registers an invocation."""
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls.invalidate_invocation_typeadapter()
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def get_invocation_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantic TypeAdapter for the union of all invocation types.
|
||||
|
||||
This is used to parse serialized invocations into the correct invocation class.
|
||||
|
||||
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
|
||||
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
|
||||
the updated allowlist and denylist.
|
||||
|
||||
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
|
||||
"""
|
||||
return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")])
|
||||
|
||||
@classmethod
|
||||
def invalidate_invocation_typeadapter(cls) -> None:
|
||||
"""Invalidates the cached invocation type adapter."""
|
||||
cls.get_invocation_typeadapter.cache_clear()
|
||||
|
||||
@classmethod
|
||||
def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]:
|
||||
"""Gets all invocations, respecting the allowlist and denylist."""
|
||||
app_config = get_config()
|
||||
allowed_invocations: set[type[BaseInvocation]] = set()
|
||||
for sc in cls._invocation_classes:
|
||||
invocation_type = sc.get_type()
|
||||
is_in_allowlist = (
|
||||
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
||||
)
|
||||
is_in_denylist = (
|
||||
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
|
||||
)
|
||||
if is_in_allowlist and not is_in_denylist:
|
||||
allowed_invocations.add(sc)
|
||||
return allowed_invocations
|
||||
|
||||
@classmethod
|
||||
def get_invocations_map(cls) -> dict[str, type[BaseInvocation]]:
|
||||
"""Gets a map of all invocation types to their invocation classes."""
|
||||
return {i.get_type(): i for i in cls.get_invocation_classes()}
|
||||
|
||||
@classmethod
|
||||
def get_invocation_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation types."""
|
||||
return (i.get_type() for i in cls.get_invocation_classes())
|
||||
|
||||
@classmethod
|
||||
def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] | None:
|
||||
"""Gets the invocation class for a given invocation type."""
|
||||
return cls.get_invocations_map().get(invocation_type)
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
|
||||
"""Registers an invocation output."""
|
||||
cls._output_classes.add(output)
|
||||
cls.invalidate_output_typeadapter()
|
||||
|
||||
@classmethod
|
||||
def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
|
||||
"""Gets all invocation outputs."""
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantic TypeAdapter for the union of all invocation output types.
|
||||
|
||||
This is used to parse serialized invocation outputs into the correct invocation output class.
|
||||
|
||||
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
|
||||
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
|
||||
the updated allowlist and denylist.
|
||||
|
||||
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
|
||||
"""
|
||||
return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")])
|
||||
|
||||
@classmethod
|
||||
def invalidate_output_typeadapter(cls) -> None:
|
||||
"""Invalidates the cached invocation output type adapter."""
|
||||
cls.get_output_typeadapter.cache_clear()
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation output types."""
|
||||
return (i.get_type() for i in cls.get_output_classes())
|
||||
|
||||
|
||||
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
|
||||
"id",
|
||||
"is_intermediate",
|
||||
@@ -453,8 +461,8 @@ def invocation(
|
||||
node_pack = cls.__module__.split(".")[0]
|
||||
|
||||
# Handle the case where an existing node is being clobbered by the one we are registering
|
||||
if invocation_type in BaseInvocation.get_invocation_types():
|
||||
clobbered_invocation = BaseInvocation.get_invocation_for_type(invocation_type)
|
||||
if invocation_type in InvocationRegistry.get_invocation_types():
|
||||
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
|
||||
# This should always be true - we just checked if the invocation type was in the set
|
||||
assert clobbered_invocation is not None
|
||||
|
||||
@@ -539,8 +547,7 @@ def invocation(
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
|
||||
# TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
|
||||
BaseInvocation.register_invocation(cls) # type: ignore
|
||||
InvocationRegistry.register_invocation(cls)
|
||||
|
||||
return cls
|
||||
|
||||
@@ -565,7 +572,7 @@ def invocation_output(
|
||||
if re.compile(r"^\S+$").match(output_type) is None:
|
||||
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
|
||||
|
||||
if output_type in BaseInvocationOutput.get_output_types():
|
||||
if output_type in InvocationRegistry.get_output_types():
|
||||
raise ValueError(f'Invocation type "{output_type}" already exists')
|
||||
|
||||
validate_fields(cls.model_fields, output_type)
|
||||
@@ -586,7 +593,7 @@ def invocation_output(
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
|
||||
BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly?
|
||||
InvocationRegistry.register_output(cls)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
@@ -1089,12 +1089,13 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
|
||||
@invocation(
|
||||
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.0"
|
||||
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1"
|
||||
)
|
||||
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
|
||||
The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
|
||||
The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
|
||||
If the fade size is 0, the mask is returned as-is.
|
||||
"""
|
||||
|
||||
mask: ImageField = InputField(description="The mask to expand")
|
||||
@@ -1104,6 +1105,11 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_mask = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
|
||||
if self.fade_size_px == 0:
|
||||
# If the fade size is 0, just return the mask as-is.
|
||||
image_dto = context.images.save(image=pil_mask, image_category=ImageCategory.MASK)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
np_mask = numpy.array(pil_mask)
|
||||
|
||||
# Threshold the mask to create a binary mask - 0 for black, 255 for white
|
||||
@@ -1141,8 +1147,21 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
coeffs = numpy.polyfit(x_control, y_control, 3)
|
||||
poly = numpy.poly1d(coeffs)
|
||||
|
||||
# Evaluate and clip the smooth mapping
|
||||
feather = numpy.clip(poly(d_norm), 0, 1)
|
||||
# Evaluate the polynomial
|
||||
feather = poly(d_norm)
|
||||
|
||||
# The polynomial fit isn't perfect. Points beyond the fade distance are likely to be slightly less than 1.0,
|
||||
# even though the control points indicate that they should be exactly 1.0. This is due to the nature of the
|
||||
# polynomial fit, which is a best approximation of the control points but not an exact match.
|
||||
|
||||
# When this occurs, the area outside the mask and fade-out will not be 100% transparent. For example, it may
|
||||
# have an alpha value of 1 instead of 0. So we must force pixels at or beyond the fade distance to exactly 1.0.
|
||||
|
||||
# Force pixels at or beyond the fade distance to exactly 1.0
|
||||
feather = numpy.where(d_norm >= 1.0, 1.0, feather)
|
||||
|
||||
# Clip any other values to ensure they're in the valid range [0,1]
|
||||
feather = numpy.clip(feather, 0, 1)
|
||||
|
||||
# Build final image.
|
||||
np_result = numpy.where(black_mask == 1, 0, (feather * 255).astype(numpy.uint8))
|
||||
|
||||
@@ -21,10 +21,16 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
|
||||
|
||||
:param output_dir: The folder where the serialized objects will be stored
|
||||
:param safe_globals: A list of types to be added to the safe globals for torch serialization
|
||||
:param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir: Path, ephemeral: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
output_dir: Path,
|
||||
safe_globals: list[type],
|
||||
ephemeral: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._ephemeral = ephemeral
|
||||
self._base_output_dir = output_dir
|
||||
@@ -42,6 +48,8 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir
|
||||
self.__obj_class_name: Optional[str] = None
|
||||
|
||||
torch.serialization.add_safe_globals(safe_globals) if safe_globals else None
|
||||
|
||||
def load(self, name: str) -> T:
|
||||
file_path = self._get_path(name)
|
||||
try:
|
||||
|
||||
@@ -21,6 +21,7 @@ from invokeai.app.invocations import * # noqa: F401 F403
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationRegistry,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -283,7 +284,7 @@ class AnyInvocation(BaseInvocation):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
def validate_invocation(v: Any) -> "AnyInvocation":
|
||||
return BaseInvocation.get_typeadapter().validate_python(v)
|
||||
return InvocationRegistry.get_invocation_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation)
|
||||
|
||||
@@ -294,7 +295,7 @@ class AnyInvocation(BaseInvocation):
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
oneOf: list[dict[str, str]] = []
|
||||
names = [i.__name__ for i in BaseInvocation.get_invocations()]
|
||||
names = [i.__name__ for i in InvocationRegistry.get_invocation_classes()]
|
||||
for name in sorted(names):
|
||||
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||
return {"oneOf": oneOf}
|
||||
@@ -304,7 +305,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
||||
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
|
||||
return BaseInvocationOutput.get_typeadapter().validate_python(v)
|
||||
return InvocationRegistry.get_output_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation_output)
|
||||
|
||||
@@ -316,7 +317,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
|
||||
oneOf: list[dict[str, str]] = []
|
||||
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
|
||||
names = [i.__name__ for i in InvocationRegistry.get_output_classes()]
|
||||
for name in sorted(names):
|
||||
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||
return {"oneOf": oneOf}
|
||||
|
||||
@@ -4,7 +4,10 @@ from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from pydantic.json_schema import models_json_schema
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
InvocationRegistry,
|
||||
UIConfigBase,
|
||||
)
|
||||
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
@@ -56,14 +59,14 @@ def get_openapi_func(
|
||||
invocation_output_map_required: list[str] = []
|
||||
|
||||
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
|
||||
for output in BaseInvocationOutput.get_outputs():
|
||||
for output in InvocationRegistry.get_output_classes():
|
||||
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||
move_defs_to_top_level(openapi_schema, json_schema)
|
||||
openapi_schema["components"]["schemas"][output.__name__] = json_schema
|
||||
|
||||
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
|
||||
# property, so we'll just do it all manually.
|
||||
for invocation in BaseInvocation.get_invocations():
|
||||
for invocation in InvocationRegistry.get_invocation_classes():
|
||||
json_schema = invocation.model_json_schema(
|
||||
mode="serialization", ref_template="#/components/schemas/{model}"
|
||||
)
|
||||
|
||||
23
invokeai/backend/flux/flux_state_dict_utils.py
Normal file
23
invokeai/backend/flux/flux_state_dict_utils.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.model_manager.legacy_probe import CkptType
|
||||
|
||||
|
||||
def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None:
|
||||
"""Gets the in channels from the state dict."""
|
||||
|
||||
# "Standard" FLUX models use "img_in.weight", but some community fine tunes use
|
||||
# "model.diffusion_model.img_in.weight". Known models that use the latter key:
|
||||
# - https://civitai.com/models/885098?modelVersionId=990775
|
||||
# - https://civitai.com/models/1018060?modelVersionId=1596255
|
||||
# - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133
|
||||
|
||||
keys = {"img_in.weight", "model.diffusion_model.img_in.weight"}
|
||||
|
||||
for key in keys:
|
||||
val = state_dict.get(key)
|
||||
if val is not None:
|
||||
return val.shape[1]
|
||||
|
||||
return None
|
||||
@@ -30,19 +30,18 @@ from inspect import isabstract
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, Optional, TypeAlias, Union
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_hash.hash_validator import validate_hash
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyVariant,
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
FluxLoRAFormat,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
@@ -51,9 +50,8 @@ from invokeai.backend.model_manager.taxonomy import (
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -97,68 +95,6 @@ class ControlAdapterDefaultSettings(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ModelOnDisk:
|
||||
"""A utility class representing a model stored on disk."""
|
||||
|
||||
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
|
||||
self.path = path
|
||||
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
|
||||
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
||||
self.name = path.stem
|
||||
else:
|
||||
self.name = path.name
|
||||
self.hash_algo = hash_algo
|
||||
|
||||
def hash(self):
|
||||
return ModelHash(algorithm=self.hash_algo).hash(self.path)
|
||||
|
||||
def size(self):
|
||||
if self.format_type == ModelFormat.Checkpoint:
|
||||
return self.path.stat().st_size
|
||||
return sum(file.stat().st_size for file in self.path.rglob("*"))
|
||||
|
||||
def component_paths(self):
|
||||
if self.format_type == ModelFormat.Checkpoint:
|
||||
return {self.path}
|
||||
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
|
||||
return {f for f in self.path.rglob("*") if f.suffix in extensions}
|
||||
|
||||
def repo_variant(self):
|
||||
if self.format_type == ModelFormat.Checkpoint:
|
||||
return None
|
||||
|
||||
weight_files = list(self.path.glob("**/*.safetensors"))
|
||||
weight_files.extend(list(self.path.glob("**/*.bin")))
|
||||
for x in weight_files:
|
||||
if ".fp16" in x.suffixes:
|
||||
return ModelRepoVariant.FP16
|
||||
if "openvino_model" in x.name:
|
||||
return ModelRepoVariant.OpenVINO
|
||||
if "flax_model" in x.name:
|
||||
return ModelRepoVariant.Flax
|
||||
if x.suffix == ".onnx":
|
||||
return ModelRepoVariant.ONNX
|
||||
return ModelRepoVariant.Default
|
||||
|
||||
@staticmethod
|
||||
def load_state_dict(path: Path):
|
||||
with SilenceWarnings():
|
||||
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
|
||||
scan_result = scan_file_path(path)
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
|
||||
checkpoint = torch.load(path, map_location="cpu")
|
||||
elif path.suffix.endswith(".gguf"):
|
||||
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
|
||||
elif path.suffix.endswith(".safetensors"):
|
||||
checkpoint = safetensors.torch.load_file(path)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized model extension: {path.suffix}")
|
||||
|
||||
state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
return state_dict
|
||||
|
||||
|
||||
class MatchSpeed(int, Enum):
|
||||
"""Represents the estimated runtime speed of a config's 'matches' method."""
|
||||
|
||||
@@ -233,16 +169,18 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
Created to deprecate ModelProbe.probe
|
||||
"""
|
||||
candidates = ModelConfigBase._USING_CLASSIFY_API
|
||||
sorted_by_match_speed = sorted(candidates, key=lambda cls: cls._MATCH_SPEED)
|
||||
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
|
||||
mod = ModelOnDisk(model_path, hash_algo)
|
||||
|
||||
for config_cls in sorted_by_match_speed:
|
||||
try:
|
||||
return config_cls.from_model_on_disk(mod, **overrides)
|
||||
except InvalidModelConfigException:
|
||||
logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
|
||||
if not config_cls.matches(mod):
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
|
||||
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
|
||||
continue
|
||||
else:
|
||||
return config_cls.from_model_on_disk(mod, **overrides)
|
||||
|
||||
raise InvalidModelConfigException("No valid config found")
|
||||
|
||||
@@ -282,12 +220,12 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
if "source_type" in overrides:
|
||||
overrides["source_type"] = ModelSourceType(overrides["source_type"])
|
||||
|
||||
if "variant" in overrides:
|
||||
overrides["variant"] = ModelVariantType(overrides["variant"])
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
|
||||
"""Creates an instance of this config or raises InvalidModelConfigException."""
|
||||
if not cls.matches(mod):
|
||||
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
|
||||
|
||||
fields = cls.parse(mod)
|
||||
cls.cast_overrides(overrides)
|
||||
fields.update(overrides)
|
||||
@@ -344,6 +282,38 @@ class LoRAConfigBase(ABC, BaseModel):
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
|
||||
@classmethod
|
||||
def flux_lora_format(cls, mod: ModelOnDisk):
|
||||
key = "FLUX_LORA_FORMAT"
|
||||
if key in mod.cache:
|
||||
return mod.cache[key]
|
||||
|
||||
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
|
||||
|
||||
sd = mod.load_state_dict(mod.path)
|
||||
value = flux_format_from_state_dict(sd)
|
||||
mod.cache[key] = value
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def base_model(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
if cls.flux_lora_format(mod):
|
||||
return BaseModelType.Flux
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
# If we've gotten here, we assume that the model is a Stable Diffusion model
|
||||
token_vector_length = lora_token_vector_length(state_dict)
|
||||
if token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_vector_length == 1280:
|
||||
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
||||
elif token_vector_length == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException("Unknown LoRA type")
|
||||
|
||||
|
||||
class T5EncoderConfigBase(ABC, BaseModel):
|
||||
"""Base class for diffusers-style models."""
|
||||
@@ -359,11 +329,40 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin,
|
||||
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
|
||||
class LoRALyCORISConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
if mod.path.is_dir():
|
||||
return False
|
||||
|
||||
# Avoid false positive match against ControlLoRA and Diffusers
|
||||
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
return False
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
for key in state_dict.keys():
|
||||
if type(key) is int:
|
||||
continue
|
||||
|
||||
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
|
||||
return True
|
||||
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
|
||||
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
|
||||
if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
return {
|
||||
"base": cls.base_model(mod),
|
||||
}
|
||||
|
||||
|
||||
class ControlAdapterConfigBase(ABC, BaseModel):
|
||||
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
|
||||
@@ -387,11 +386,26 @@ class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, Mod
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class LoRADiffusersConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
||||
"""Model config for LoRA/Diffusers models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
if mod.path.is_file():
|
||||
return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
|
||||
|
||||
suffixes = ["bin", "safetensors"]
|
||||
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
|
||||
return any(wf.exists() for wf in weight_files)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
return {
|
||||
"base": cls.base_model(mod),
|
||||
}
|
||||
|
||||
|
||||
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
@@ -563,7 +577,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
if mod.format_type == ModelFormat.Checkpoint:
|
||||
if mod.path.is_file():
|
||||
return False
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
@@ -14,6 +14,7 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
is_state_dict_instantx_controlnet,
|
||||
is_state_dict_xlabs_controlnet,
|
||||
)
|
||||
from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict
|
||||
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
|
||||
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
@@ -564,7 +565,14 @@ class CheckpointProbeBase(ProbeBase):
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
|
||||
if base_type == BaseModelType.Flux:
|
||||
in_channels = state_dict["img_in.weight"].shape[1]
|
||||
in_channels = get_flux_in_channels_from_state_dict(state_dict)
|
||||
|
||||
if in_channels is None:
|
||||
# If we cannot find the in_channels, we assume that this is a normal variant. Log a warning.
|
||||
logger.warning(
|
||||
f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant."
|
||||
)
|
||||
return ModelVariantType.Normal
|
||||
|
||||
# FLUX Model variant types are distinguished by input channels:
|
||||
# - Unquantized Dev and Schnell have in_channels=64
|
||||
|
||||
96
invokeai/backend/model_manager/model_on_disk.py
Normal file
96
invokeai/backend/model_manager/model_on_disk.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, TypeAlias
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
StateDict: TypeAlias = dict[str | int, Any] # When are the keys int?
|
||||
|
||||
|
||||
class ModelOnDisk:
|
||||
"""A utility class representing a model stored on disk."""
|
||||
|
||||
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
|
||||
self.path = path
|
||||
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
||||
self.name = path.stem
|
||||
else:
|
||||
self.name = path.name
|
||||
self.hash_algo = hash_algo
|
||||
# Having a cache helps users of ModelOnDisk (i.e. configs) to save state
|
||||
# This prevents redundant computations during matching and parsing
|
||||
self.cache = {"_CACHED_STATE_DICTS": {}}
|
||||
|
||||
def hash(self) -> str:
|
||||
return ModelHash(algorithm=self.hash_algo).hash(self.path)
|
||||
|
||||
def size(self) -> int:
|
||||
if self.path.is_file():
|
||||
return self.path.stat().st_size
|
||||
return sum(file.stat().st_size for file in self.path.rglob("*"))
|
||||
|
||||
def component_paths(self) -> set[Path]:
|
||||
if self.path.is_file():
|
||||
return {self.path}
|
||||
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
|
||||
return {f for f in self.path.rglob("*") if f.suffix in extensions}
|
||||
|
||||
def repo_variant(self) -> Optional[ModelRepoVariant]:
|
||||
if self.path.is_file():
|
||||
return None
|
||||
|
||||
weight_files = list(self.path.glob("**/*.safetensors"))
|
||||
weight_files.extend(list(self.path.glob("**/*.bin")))
|
||||
for x in weight_files:
|
||||
if ".fp16" in x.suffixes:
|
||||
return ModelRepoVariant.FP16
|
||||
if "openvino_model" in x.name:
|
||||
return ModelRepoVariant.OpenVINO
|
||||
if "flax_model" in x.name:
|
||||
return ModelRepoVariant.Flax
|
||||
if x.suffix == ".onnx":
|
||||
return ModelRepoVariant.ONNX
|
||||
return ModelRepoVariant.Default
|
||||
|
||||
def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
|
||||
sd_cache = self.cache["_CACHED_STATE_DICTS"]
|
||||
|
||||
if path in sd_cache:
|
||||
return sd_cache[path]
|
||||
|
||||
if not path:
|
||||
components = list(self.component_paths())
|
||||
match components:
|
||||
case []:
|
||||
raise ValueError("No weight files found for this model")
|
||||
case [p]:
|
||||
path = p
|
||||
case ps if len(ps) >= 2:
|
||||
raise ValueError(
|
||||
f"Multiple weight files found for this model: {ps}. "
|
||||
f"Please specify the intended file using the 'path' argument"
|
||||
)
|
||||
|
||||
with SilenceWarnings():
|
||||
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
|
||||
scan_result = scan_file_path(path)
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
|
||||
checkpoint = torch.load(path, map_location="cpu")
|
||||
assert isinstance(checkpoint, dict)
|
||||
elif path.suffix.endswith(".gguf"):
|
||||
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
|
||||
elif path.suffix.endswith(".safetensors"):
|
||||
checkpoint = safetensors.torch.load_file(path)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized model extension: {path.suffix}")
|
||||
|
||||
state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
sd_cache[path] = state_dict
|
||||
return state_dict
|
||||
@@ -126,4 +126,13 @@ class ModelSourceType(str, Enum):
|
||||
HFRepoID = "hf_repo_id"
|
||||
|
||||
|
||||
class FluxLoRAFormat(str, Enum):
|
||||
"""Flux LoRA formats."""
|
||||
|
||||
Diffusers = "flux.diffusers"
|
||||
Kohya = "flux.kohya"
|
||||
OneTrainer = "flux.onetrainer"
|
||||
Control = "flux.control"
|
||||
|
||||
|
||||
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
|
||||
|
||||
24
invokeai/backend/patches/lora_conversions/formats.py
Normal file
24
invokeai/backend/patches/lora_conversions/formats.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
|
||||
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
|
||||
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_kohya_format,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_onetrainer_format,
|
||||
)
|
||||
|
||||
|
||||
def flux_format_from_state_dict(state_dict):
|
||||
if is_state_dict_likely_in_flux_kohya_format(state_dict):
|
||||
return FluxLoRAFormat.Kohya
|
||||
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
|
||||
return FluxLoRAFormat.OneTrainer
|
||||
elif is_state_dict_likely_in_flux_diffusers_format(state_dict):
|
||||
return FluxLoRAFormat.Diffusers
|
||||
elif is_state_dict_likely_flux_control(state_dict):
|
||||
return FluxLoRAFormat.Control
|
||||
else:
|
||||
return None
|
||||
@@ -69,6 +69,9 @@ class SD3ConditioningInfo:
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
# If you change this class, adding more types, you _must_ update the instantiation of ObjectSerializerDisk in
|
||||
# invokeai/app/api/dependencies.py, adding the types to the list of safe globals. If you do not, torch will be
|
||||
# unable to deserialize the object and will raise an error.
|
||||
conditionings: (
|
||||
List[BasicConditioningInfo]
|
||||
| List[SDXLConditioningInfo]
|
||||
|
||||
@@ -62,7 +62,7 @@
|
||||
"@nanostores/react": "^0.7.3",
|
||||
"@reduxjs/toolkit": "2.6.1",
|
||||
"@roarr/browser-log-writer": "^1.3.0",
|
||||
"@xyflow/react": "^12.4.2",
|
||||
"@xyflow/react": "^12.5.3",
|
||||
"async-mutex": "^0.5.0",
|
||||
"chakra-react-select": "^4.9.2",
|
||||
"cmdk": "^1.0.0",
|
||||
@@ -162,5 +162,6 @@
|
||||
},
|
||||
"engines": {
|
||||
"pnpm": "8"
|
||||
}
|
||||
},
|
||||
"packageManager": "pnpm@8.15.9+sha512.499434c9d8fdd1a2794ebf4552b3b25c0a633abcee5bb15e7b5de90f32f47b513aca98cd5cfd001c31f0db454bc3804edccd578501e4ca293a6816166bbd9f81"
|
||||
}
|
||||
|
||||
48
invokeai/frontend/web/pnpm-lock.yaml
generated
48
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -36,8 +36,8 @@ dependencies:
|
||||
specifier: ^1.3.0
|
||||
version: 1.3.0
|
||||
'@xyflow/react':
|
||||
specifier: ^12.4.2
|
||||
version: 12.4.2(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
|
||||
specifier: ^12.5.3
|
||||
version: 12.5.3(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
|
||||
async-mutex:
|
||||
specifier: ^0.5.0
|
||||
version: 0.5.0
|
||||
@@ -3323,7 +3323,7 @@ packages:
|
||||
/@types/d3-drag@3.0.7:
|
||||
resolution: {integrity: sha512-HE3jVKlzU9AaMazNufooRJ5ZpWmLIoc90A37WU2JMmeq28w1FQqCZswHZ3xR+SuxYftzHq6WU6KJHvqxKzTxxQ==}
|
||||
dependencies:
|
||||
'@types/d3-selection': 3.0.10
|
||||
'@types/d3-selection': 3.0.11
|
||||
dev: false
|
||||
|
||||
/@types/d3-interpolate@3.0.4:
|
||||
@@ -3332,21 +3332,21 @@ packages:
|
||||
'@types/d3-color': 3.1.3
|
||||
dev: false
|
||||
|
||||
/@types/d3-selection@3.0.10:
|
||||
resolution: {integrity: sha512-cuHoUgS/V3hLdjJOLTT691+G2QoqAjCVLmr4kJXR4ha56w1Zdu8UUQ5TxLRqudgNjwXeQxKMq4j+lyf9sWuslg==}
|
||||
/@types/d3-selection@3.0.11:
|
||||
resolution: {integrity: sha512-bhAXu23DJWsrI45xafYpkQ4NtcKMwWnAC/vKrd2l+nxMFuvOT3XMYTIj2opv8vq8AO5Yh7Qac/nSeP/3zjTK0w==}
|
||||
dev: false
|
||||
|
||||
/@types/d3-transition@3.0.8:
|
||||
resolution: {integrity: sha512-ew63aJfQ/ms7QQ4X7pk5NxQ9fZH/z+i24ZfJ6tJSfqxJMrYLiK01EAs2/Rtw/JreGUsS3pLPNV644qXFGnoZNQ==}
|
||||
/@types/d3-transition@3.0.9:
|
||||
resolution: {integrity: sha512-uZS5shfxzO3rGlu0cC3bjmMFKsXv+SmZZcgp0KD22ts4uGXp5EVYGzu/0YdwZeKmddhcAccYtREJKkPfXkZuCg==}
|
||||
dependencies:
|
||||
'@types/d3-selection': 3.0.10
|
||||
'@types/d3-selection': 3.0.11
|
||||
dev: false
|
||||
|
||||
/@types/d3-zoom@3.0.8:
|
||||
resolution: {integrity: sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==}
|
||||
dependencies:
|
||||
'@types/d3-interpolate': 3.0.4
|
||||
'@types/d3-selection': 3.0.10
|
||||
'@types/d3-selection': 3.0.11
|
||||
dev: false
|
||||
|
||||
/@types/diff-match-patch@1.0.36:
|
||||
@@ -3951,28 +3951,28 @@ packages:
|
||||
resolution: {integrity: sha512-N8tkAACJx2ww8vFMneJmaAgmjAG1tnVBZJRLRcx061tmsLRZHSEZSLuGWnwPtunsSLvSqXQ2wfp7Mgqg1I+2dQ==}
|
||||
dev: false
|
||||
|
||||
/@xyflow/react@12.4.2(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-AFJKVc/fCPtgSOnRst3xdYJwiEcUN9lDY7EO/YiRvFHYCJGgfzg+jpvZjkTOnBLGyrMJre9378pRxAc3fsR06A==}
|
||||
/@xyflow/react@12.5.3(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-saovy/aQRoW8qQoIqMFUtmC3F6oEV7n6+J1pVbhSG45NI/hOFvK0qozsIPKqX5Va6lGQnkl/o53NHLja3NiweQ==}
|
||||
peerDependencies:
|
||||
react: '>=17'
|
||||
react-dom: '>=17'
|
||||
dependencies:
|
||||
'@xyflow/system': 0.0.50
|
||||
'@xyflow/system': 0.0.53
|
||||
classcat: 5.0.5
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
zustand: 4.5.5(@types/react@18.3.11)(react@18.3.1)
|
||||
zustand: 4.5.6(@types/react@18.3.11)(react@18.3.1)
|
||||
transitivePeerDependencies:
|
||||
- '@types/react'
|
||||
- immer
|
||||
dev: false
|
||||
|
||||
/@xyflow/system@0.0.50:
|
||||
resolution: {integrity: sha512-HVUZd4LlY88XAaldFh2nwVxDOcdIBxGpQ5txzwfJPf+CAjj2BfYug1fHs2p4yS7YO8H6A3EFJQovBE8YuHkAdg==}
|
||||
/@xyflow/system@0.0.53:
|
||||
resolution: {integrity: sha512-QTWieiTtvNYyQAz1fxpzgtUGXNpnhfh6vvZa7dFWpWS2KOz6bEHODo/DTK3s07lDu0Bq0Db5lx/5M5mNjb9VDQ==}
|
||||
dependencies:
|
||||
'@types/d3-drag': 3.0.7
|
||||
'@types/d3-selection': 3.0.10
|
||||
'@types/d3-transition': 3.0.8
|
||||
'@types/d3-selection': 3.0.11
|
||||
'@types/d3-transition': 3.0.9
|
||||
'@types/d3-zoom': 3.0.8
|
||||
d3-drag: 3.0.0
|
||||
d3-selection: 3.0.0
|
||||
@@ -9123,6 +9123,14 @@ packages:
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/use-sync-external-store@1.5.0(react@18.3.1):
|
||||
resolution: {integrity: sha512-Rb46I4cGGVBmjamjphe8L/UnvJD+uPPtTkNvX5mZgqdbavhI4EbgIWJiIHXJ8bc/i9EQGPRh4DwEURJ552Do0A==}
|
||||
peerDependencies:
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
||||
dependencies:
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/util-deprecate@1.0.2:
|
||||
resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==}
|
||||
dev: true
|
||||
@@ -9567,8 +9575,8 @@ packages:
|
||||
/zod@3.23.8:
|
||||
resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==}
|
||||
|
||||
/zustand@4.5.5(@types/react@18.3.11)(react@18.3.1):
|
||||
resolution: {integrity: sha512-+0PALYNJNgK6hldkgDq2vLrw5f6g/jCInz52n9RTpropGgeAf/ioFUCdtsjCqu4gNhW9D01rUQBROoRjdzyn2Q==}
|
||||
/zustand@4.5.6(@types/react@18.3.11)(react@18.3.1):
|
||||
resolution: {integrity: sha512-ibr/n1hBzLLj5Y+yUcU7dYw8p6WnIVzdJbnX+1YpaScvZVF2ziugqHs+LAmHw4lWO9c/zRj+K1ncgWDQuthEdQ==}
|
||||
engines: {node: '>=12.7.0'}
|
||||
peerDependencies:
|
||||
'@types/react': '>=16.8'
|
||||
@@ -9584,5 +9592,5 @@ packages:
|
||||
dependencies:
|
||||
'@types/react': 18.3.11
|
||||
react: 18.3.1
|
||||
use-sync-external-store: 1.2.2(react@18.3.1)
|
||||
use-sync-external-store: 1.5.0(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
@@ -116,7 +116,10 @@
|
||||
"combinatorial": "Kombinatorisch",
|
||||
"saveChanges": "Änderungen speichern",
|
||||
"error_withCount_one": "{{count}} Fehler",
|
||||
"error_withCount_other": "{{count}} Fehler"
|
||||
"error_withCount_other": "{{count}} Fehler",
|
||||
"value": "Wert",
|
||||
"label": "Label",
|
||||
"systemInformation": "Systeminformationen"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Bildgröße",
|
||||
@@ -695,7 +698,10 @@
|
||||
"guidance": "Führung",
|
||||
"coherenceMode": "Modus",
|
||||
"recallMetadata": "Metadaten abrufen",
|
||||
"gaussianBlur": "Gaußsche Unschärfe"
|
||||
"gaussianBlur": "Gaußsche Unschärfe",
|
||||
"sendToUpscale": "An Hochskalieren senden",
|
||||
"useCpuNoise": "CPU-Rauschen verwenden",
|
||||
"sendToCanvas": "An Leinwand senden"
|
||||
},
|
||||
"settings": {
|
||||
"displayInProgress": "Zwischenbilder anzeigen",
|
||||
@@ -1328,7 +1334,8 @@
|
||||
"loadWorkflowDesc2": "Ihr aktueller Arbeitsablauf enthält nicht gespeicherte Änderungen.",
|
||||
"loadingTemplates": "Lade {{name}}",
|
||||
"missingSourceOrTargetHandle": "Fehlender Quell- oder Zielgriff",
|
||||
"missingSourceOrTargetNode": "Fehlender Quell- oder Zielknoten"
|
||||
"missingSourceOrTargetNode": "Fehlender Quell- oder Zielknoten",
|
||||
"showEdgeLabelsHelp": "Beschriftungen an Kanten anzeigen, um die verknüpften Knoten zu kennzeichnen"
|
||||
},
|
||||
"hrf": {
|
||||
"enableHrf": "Korrektur für hohe Auflösungen",
|
||||
|
||||
@@ -115,7 +115,8 @@
|
||||
"error_withCount_many": "{{count}} errori",
|
||||
"error_withCount_other": "{{count}} errori",
|
||||
"value": "Valore",
|
||||
"label": "Etichetta"
|
||||
"label": "Etichetta",
|
||||
"systemInformation": "Informazioni di sistema"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@@ -715,7 +716,8 @@
|
||||
"collectionNumberLTMin": "{{value}} < {{minimum}} (incr min)",
|
||||
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (excl max)",
|
||||
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (excl min)",
|
||||
"collectionEmpty": "raccolta vuota"
|
||||
"collectionEmpty": "raccolta vuota",
|
||||
"batchNodeCollectionSizeMismatchNoGroupId": "Dimensione della raccolta di gruppo nel Lotto non corrisponde"
|
||||
},
|
||||
"useCpuNoise": "Usa la CPU per generare rumore",
|
||||
"iterations": "Iterazioni",
|
||||
@@ -2365,8 +2367,9 @@
|
||||
"watchRecentReleaseVideos": "Guarda i video su questa versione",
|
||||
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
|
||||
"items": [
|
||||
"Flussi di lavoro: nuova e migliorata libreria dei flussi di lavoro.",
|
||||
"FLUX: supporto per FLUX Redux e FLUX Fill in Flussi di lavoro e Tela."
|
||||
"Flussi di lavoro: supporto per menu a discesa di stringhe personalizzate nel Generatore di Flussi di lavoro.",
|
||||
"FLUX: supporto per FLUX Fill in Flussi di lavoro e Tela.",
|
||||
"LLaVA OneVision VLLM: supporto beta nei flussi di lavoro."
|
||||
]
|
||||
},
|
||||
"system": {
|
||||
|
||||
@@ -237,7 +237,10 @@
|
||||
"row": "Hàng",
|
||||
"board": "Bảng",
|
||||
"saveChanges": "Lưu Thay Đổi",
|
||||
"error_withCount_other": "{{count}} lỗi"
|
||||
"error_withCount_other": "{{count}} lỗi",
|
||||
"value": "Giá Trị",
|
||||
"label": "Nhãn Tên",
|
||||
"systemInformation": "Thông Tin Hệ Thống"
|
||||
},
|
||||
"prompt": {
|
||||
"addPromptTrigger": "Thêm Prompt Trigger",
|
||||
@@ -2300,7 +2303,10 @@
|
||||
"minimum": "Tối Thiểu",
|
||||
"maximum": "Tối Đa",
|
||||
"containerRowLayout": "Hộp Chứa (bố cục hàng)",
|
||||
"containerColumnLayout": "Hộp Chứa (bố cục cột)"
|
||||
"containerColumnLayout": "Hộp Chứa (bố cục cột)",
|
||||
"resetOptions": "Tải Lại Lựa Chọn",
|
||||
"addOption": "Thêm Lựa Chọn",
|
||||
"dropdown": "Danh Sách Thả Xuống"
|
||||
},
|
||||
"yourWorkflows": "Workflow Của Bạn",
|
||||
"browseWorkflows": "Khám Phá Workflow",
|
||||
@@ -2316,7 +2322,8 @@
|
||||
"view": "Xem",
|
||||
"deselectAll": "Huỷ Chọn Tất Cả",
|
||||
"noRecentWorkflows": "Không Có Workflows Gần Đây",
|
||||
"recommended": "Có Thể Bạn Sẽ Cần"
|
||||
"recommended": "Có Thể Bạn Sẽ Cần",
|
||||
"emptyStringPlaceholder": "<xâu ký tự trống>"
|
||||
},
|
||||
"upscaling": {
|
||||
"missingUpscaleInitialImage": "Thiếu ảnh dùng để upscale",
|
||||
@@ -2352,8 +2359,9 @@
|
||||
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
|
||||
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
|
||||
"items": [
|
||||
"Workflow: Thư Viện Workflow mới và đã được cải tiến.",
|
||||
"FLUX: Hỗ trợ FLUX Redux & FLUX Fill trong Workflow và Canvas."
|
||||
"Workflow: Hỗ trợ xâu ký tự thả xuống tùy chỉnh trong Trình Tạo Vùng Nhập.",
|
||||
"FLUX: Hỗ trợ FLUX Fill trong Workflow và Canvas.",
|
||||
"LLaVA OneVision VLLM: Hỗ trợ phiên bản Beta trong Workflow."
|
||||
]
|
||||
},
|
||||
"upsell": {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { imageUploadedClientSide } from 'features/gallery/store/actions';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
@@ -8,7 +10,8 @@ import { t } from 'i18next';
|
||||
import { omit } from 'lodash-es';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { getCategories, getListImagesUrl } from 'services/api/util';
|
||||
const log = logger('gallery');
|
||||
|
||||
/**
|
||||
@@ -34,19 +37,56 @@ let lastUploadedToastTimeout: number | null = null;
|
||||
|
||||
export const addImageUploadedFulfilledListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher: imagesApi.endpoints.uploadImage.matchFulfilled,
|
||||
matcher: isAnyOf(imagesApi.endpoints.uploadImage.matchFulfilled, imageUploadedClientSide),
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const imageDTO = action.payload;
|
||||
let imageDTO: ImageDTO;
|
||||
let silent;
|
||||
let isFirstUploadOfBatch = true;
|
||||
|
||||
if (imageUploadedClientSide.match(action)) {
|
||||
imageDTO = action.payload.imageDTO;
|
||||
silent = action.payload.silent;
|
||||
isFirstUploadOfBatch = action.payload.isFirstUploadOfBatch;
|
||||
} else if (imagesApi.endpoints.uploadImage.matchFulfilled(action)) {
|
||||
imageDTO = action.payload;
|
||||
silent = action.meta.arg.originalArgs.silent;
|
||||
isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
if (silent || imageDTO.is_intermediate) {
|
||||
// If the image is silent or intermediate, we don't want to show a toast
|
||||
return;
|
||||
}
|
||||
|
||||
if (imageUploadedClientSide.match(action)) {
|
||||
const categories = getCategories(imageDTO);
|
||||
const boardId = imageDTO.board_id ?? 'none';
|
||||
dispatch(
|
||||
imagesApi.util.invalidateTags([
|
||||
{
|
||||
type: 'ImageList',
|
||||
id: getListImagesUrl({
|
||||
board_id: boardId,
|
||||
categories,
|
||||
}),
|
||||
},
|
||||
{
|
||||
type: 'Board',
|
||||
id: boardId,
|
||||
},
|
||||
{
|
||||
type: 'BoardImagesTotal',
|
||||
id: boardId,
|
||||
},
|
||||
])
|
||||
);
|
||||
}
|
||||
const state = getState();
|
||||
|
||||
log.debug({ imageDTO }, 'Image uploaded');
|
||||
|
||||
if (action.meta.arg.originalArgs.silent || imageDTO.is_intermediate) {
|
||||
// When a "silent" upload is requested, or the image is intermediate, we can skip all post-upload actions,
|
||||
// like toasts and switching the gallery view
|
||||
return;
|
||||
}
|
||||
|
||||
const boardId = imageDTO.board_id ?? 'none';
|
||||
|
||||
const DEFAULT_UPLOADED_TOAST = {
|
||||
@@ -80,7 +120,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
*
|
||||
* Default to true to not require _all_ image upload handlers to set this value
|
||||
*/
|
||||
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
|
||||
|
||||
if (isFirstUploadOfBatch) {
|
||||
dispatch(boardIdSelected({ boardId }));
|
||||
dispatch(galleryViewChanged('assets'));
|
||||
|
||||
@@ -73,6 +73,7 @@ export type AppConfig = {
|
||||
maxUpscaleDimension?: number;
|
||||
allowPrivateBoards: boolean;
|
||||
allowPrivateStylePresets: boolean;
|
||||
allowClientSideUpload: boolean;
|
||||
disabledTabs: TabName[];
|
||||
disabledFeatures: AppFeature[];
|
||||
disabledSDFeatures: SDFeature[];
|
||||
@@ -81,7 +82,6 @@ export type AppConfig = {
|
||||
metadataFetchDebounce?: number;
|
||||
workflowFetchDebounce?: number;
|
||||
isLocal?: boolean;
|
||||
maxImageUploadCount?: number;
|
||||
sd: {
|
||||
defaultModel?: string;
|
||||
disabledControlNetModels: string[];
|
||||
|
||||
121
invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts
Normal file
121
invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts
Normal file
@@ -0,0 +1,121 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { imageUploadedClientSide } from 'features/gallery/store/actions';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { useCallback } from 'react';
|
||||
import { useCreateImageUploadEntryMutation } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
type PresignedUrlResponse = {
|
||||
fullUrl: string;
|
||||
thumbnailUrl: string;
|
||||
};
|
||||
|
||||
const isPresignedUrlResponse = (response: unknown): response is PresignedUrlResponse => {
|
||||
return typeof response === 'object' && response !== null && 'fullUrl' in response && 'thumbnailUrl' in response;
|
||||
};
|
||||
|
||||
export const useClientSideUpload = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
|
||||
const authToken = useStore($authToken);
|
||||
const [createImageUploadEntry] = useCreateImageUploadEntryMutation();
|
||||
|
||||
const clientSideUpload = useCallback(
|
||||
async (file: File, i: number): Promise<ImageDTO> => {
|
||||
const image = new Image();
|
||||
const objectURL = URL.createObjectURL(file);
|
||||
image.src = objectURL;
|
||||
let width = 0;
|
||||
let height = 0;
|
||||
let thumbnail: Blob | undefined;
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
image.onload = () => {
|
||||
width = image.naturalWidth;
|
||||
height = image.naturalHeight;
|
||||
|
||||
// Calculate thumbnail dimensions maintaining aspect ratio
|
||||
let thumbWidth = width;
|
||||
let thumbHeight = height;
|
||||
if (width > height && width > 256) {
|
||||
thumbWidth = 256;
|
||||
thumbHeight = Math.round((height * 256) / width);
|
||||
} else if (height > 256) {
|
||||
thumbHeight = 256;
|
||||
thumbWidth = Math.round((width * 256) / height);
|
||||
}
|
||||
|
||||
const canvas = document.createElement('canvas');
|
||||
canvas.width = thumbWidth;
|
||||
canvas.height = thumbHeight;
|
||||
const ctx = canvas.getContext('2d');
|
||||
ctx?.drawImage(image, 0, 0, thumbWidth, thumbHeight);
|
||||
|
||||
canvas.toBlob(
|
||||
(blob) => {
|
||||
if (blob) {
|
||||
thumbnail = blob;
|
||||
// Clean up resources
|
||||
URL.revokeObjectURL(objectURL);
|
||||
image.src = ''; // Clear image source
|
||||
image.remove(); // Remove the image element
|
||||
canvas.width = 0; // Clear canvas
|
||||
canvas.height = 0;
|
||||
resolve();
|
||||
}
|
||||
},
|
||||
'image/webp',
|
||||
0.8
|
||||
);
|
||||
};
|
||||
|
||||
// Handle load errors
|
||||
image.onerror = () => {
|
||||
URL.revokeObjectURL(objectURL);
|
||||
image.remove();
|
||||
resolve();
|
||||
};
|
||||
});
|
||||
const { presigned_url, image_dto } = await createImageUploadEntry({
|
||||
width,
|
||||
height,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
}).unwrap();
|
||||
|
||||
const response = await fetch(presigned_url, {
|
||||
method: 'GET',
|
||||
...(authToken && {
|
||||
headers: {
|
||||
Authorization: `Bearer ${authToken}`,
|
||||
},
|
||||
}),
|
||||
}).then((res) => res.json());
|
||||
|
||||
if (!isPresignedUrlResponse(response)) {
|
||||
throw new Error('Invalid response');
|
||||
}
|
||||
|
||||
const fullUrl = response.fullUrl;
|
||||
const thumbnailUrl = response.thumbnailUrl;
|
||||
|
||||
await fetch(fullUrl, {
|
||||
method: 'PUT',
|
||||
body: file,
|
||||
});
|
||||
|
||||
await fetch(thumbnailUrl, {
|
||||
method: 'PUT',
|
||||
body: thumbnail,
|
||||
});
|
||||
|
||||
dispatch(imageUploadedClientSide({ imageDTO: image_dto, silent: false, isFirstUploadOfBatch: i === 0 }));
|
||||
|
||||
return image_dto;
|
||||
},
|
||||
[autoAddBoardId, authToken, createImageUploadEntry, dispatch]
|
||||
);
|
||||
|
||||
return clientSideUpload;
|
||||
};
|
||||
@@ -3,7 +3,7 @@ import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
|
||||
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
import type { FileRejection } from 'react-dropzone';
|
||||
@@ -15,6 +15,7 @@ import type { ImageDTO } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
import type { SetOptional } from 'type-fest';
|
||||
|
||||
import { useClientSideUpload } from './useClientSideUpload';
|
||||
type UseImageUploadButtonArgs =
|
||||
| {
|
||||
isDisabled?: boolean;
|
||||
@@ -50,51 +51,65 @@ const log = logger('gallery');
|
||||
*/
|
||||
export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
|
||||
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
|
||||
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
|
||||
const [uploadImage, request] = useUploadImageMutation();
|
||||
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
|
||||
const clientSideUpload = useClientSideUpload();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onDropAccepted = useCallback(
|
||||
async (files: File[]) => {
|
||||
if (!allowMultiple) {
|
||||
if (files.length > 1) {
|
||||
log.warn('Multiple files dropped but only one allowed');
|
||||
return;
|
||||
}
|
||||
if (files.length === 0) {
|
||||
// Should never happen
|
||||
log.warn('No files dropped');
|
||||
return;
|
||||
}
|
||||
const file = files[0];
|
||||
assert(file !== undefined); // should never happen
|
||||
const imageDTO = await uploadImage({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
silent: true,
|
||||
}).unwrap();
|
||||
if (onUpload) {
|
||||
onUpload(imageDTO);
|
||||
}
|
||||
} else {
|
||||
const imageDTOs = await uploadImages(
|
||||
files.map((file, i) => ({
|
||||
try {
|
||||
if (!allowMultiple) {
|
||||
if (files.length > 1) {
|
||||
log.warn('Multiple files dropped but only one allowed');
|
||||
return;
|
||||
}
|
||||
if (files.length === 0) {
|
||||
// Should never happen
|
||||
log.warn('No files dropped');
|
||||
return;
|
||||
}
|
||||
const file = files[0];
|
||||
assert(file !== undefined); // should never happen
|
||||
const imageDTO = await uploadImage({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
silent: false,
|
||||
isFirstUploadOfBatch: i === 0,
|
||||
}))
|
||||
);
|
||||
if (onUpload) {
|
||||
onUpload(imageDTOs);
|
||||
silent: true,
|
||||
}).unwrap();
|
||||
if (onUpload) {
|
||||
onUpload(imageDTO);
|
||||
}
|
||||
} else {
|
||||
let imageDTOs: ImageDTO[] = [];
|
||||
if (isClientSideUploadEnabled) {
|
||||
imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i)));
|
||||
} else {
|
||||
imageDTOs = await uploadImages(
|
||||
files.map((file, i) => ({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
silent: false,
|
||||
isFirstUploadOfBatch: i === 0,
|
||||
}))
|
||||
);
|
||||
}
|
||||
if (onUpload) {
|
||||
onUpload(imageDTOs);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
toast({
|
||||
id: 'UPLOAD_FAILED',
|
||||
title: t('toast.imageUploadFailed'),
|
||||
status: 'error',
|
||||
});
|
||||
}
|
||||
},
|
||||
[allowMultiple, autoAddBoardId, onUpload, uploadImage]
|
||||
[allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload, t]
|
||||
);
|
||||
|
||||
const onDropRejected = useCallback(
|
||||
@@ -105,10 +120,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
|
||||
file: rejection.file.path,
|
||||
}));
|
||||
log.error({ errors }, 'Invalid upload');
|
||||
const description =
|
||||
maxImageUploadCount === undefined
|
||||
? t('toast.uploadFailedInvalidUploadDesc')
|
||||
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
|
||||
const description = t('toast.uploadFailedInvalidUploadDesc');
|
||||
|
||||
toast({
|
||||
id: 'UPLOAD_FAILED',
|
||||
@@ -120,7 +132,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
|
||||
return;
|
||||
}
|
||||
},
|
||||
[maxImageUploadCount, t]
|
||||
[t]
|
||||
);
|
||||
|
||||
const {
|
||||
@@ -137,8 +149,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
|
||||
onDropRejected,
|
||||
disabled: isDisabled,
|
||||
noDrag: true,
|
||||
multiple: allowMultiple && (maxImageUploadCount === undefined || maxImageUploadCount > 1),
|
||||
maxFiles: maxImageUploadCount,
|
||||
multiple: allowMultiple,
|
||||
});
|
||||
|
||||
return { getUploadButtonProps, getUploadInputProps, openUploader, request };
|
||||
|
||||
@@ -14,8 +14,9 @@ import WavyLine from 'common/components/WavyLine';
|
||||
import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectActiveRasterLayerEntities } from 'features/controlLayers/store/selectors';
|
||||
import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
|
||||
|
||||
const selectHasRasterLayersWithContent = createSelector(
|
||||
selectActiveRasterLayerEntities,
|
||||
@@ -26,6 +27,7 @@ export const ParamDenoisingStrength = memo(() => {
|
||||
const img2imgStrength = useAppSelector(selectImg2imgStrength);
|
||||
const dispatch = useAppDispatch();
|
||||
const hasRasterLayersWithContent = useAppSelector(selectHasRasterLayersWithContent);
|
||||
const selectedModelConfig = useSelectedModelConfig();
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
@@ -39,8 +41,24 @@ export const ParamDenoisingStrength = memo(() => {
|
||||
|
||||
const [invokeBlue300] = useToken('colors', ['invokeBlue.300']);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
if (!hasRasterLayersWithContent) {
|
||||
// Denoising strength does nothing if there are no raster layers w/ content
|
||||
return true;
|
||||
}
|
||||
if (
|
||||
selectedModelConfig?.type === 'main' &&
|
||||
selectedModelConfig?.base === 'flux' &&
|
||||
selectedModelConfig.variant === 'inpaint'
|
||||
) {
|
||||
// Denoising strength is ignored by FLUX Fill, which is indicated by the variant being 'inpaint'
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}, [hasRasterLayersWithContent, selectedModelConfig]);
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={!hasRasterLayersWithContent} p={1} justifyContent="space-between" h={8}>
|
||||
<FormControl isDisabled={isDisabled} p={1} justifyContent="space-between" h={8}>
|
||||
<Flex gap={3} alignItems="center">
|
||||
<InformationalPopover feature="paramDenoisingStrength">
|
||||
<FormLabel mr={0}>{`${t('parameters.denoisingStrength')}`}</FormLabel>
|
||||
@@ -49,7 +67,7 @@ export const ParamDenoisingStrength = memo(() => {
|
||||
<WavyLine amplitude={img2imgStrength * 10} stroke={invokeBlue300} strokeWidth={1} width={40} height={14} />
|
||||
)}
|
||||
</Flex>
|
||||
{hasRasterLayersWithContent ? (
|
||||
{!isDisabled ? (
|
||||
<>
|
||||
<CompositeSlider
|
||||
step={config.coarseStep}
|
||||
|
||||
@@ -8,12 +8,13 @@ import { useStore } from '@nanostores/react';
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { $focusedRegion } from 'common/hooks/focus';
|
||||
import { useClientSideUpload } from 'common/hooks/useClientSideUpload';
|
||||
import { setFileToPaste } from 'features/controlLayers/components/CanvasPasteModal';
|
||||
import { DndDropOverlay } from 'features/dnd/DndDropOverlay';
|
||||
import type { DndTargetState } from 'features/dnd/types';
|
||||
import { $imageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
|
||||
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
@@ -53,13 +54,6 @@ const zUploadFile = z
|
||||
(file) => ({ message: `File extension .${file.name.split('.').at(-1)} is not supported` })
|
||||
);
|
||||
|
||||
const getFilesSchema = (max?: number) => {
|
||||
if (max === undefined) {
|
||||
return z.array(zUploadFile);
|
||||
}
|
||||
return z.array(zUploadFile).max(max);
|
||||
};
|
||||
|
||||
const sx = {
|
||||
position: 'absolute',
|
||||
top: 2,
|
||||
@@ -74,22 +68,19 @@ const sx = {
|
||||
export const FullscreenDropzone = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
|
||||
const [dndState, setDndState] = useState<DndTargetState>('idle');
|
||||
const activeTab = useAppSelector(selectActiveTab);
|
||||
const isImageViewerOpen = useStore($imageViewer);
|
||||
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
|
||||
const clientSideUpload = useClientSideUpload();
|
||||
|
||||
const validateAndUploadFiles = useCallback(
|
||||
(files: File[]) => {
|
||||
async (files: File[]) => {
|
||||
const { getState } = getStore();
|
||||
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
|
||||
const parseResult = uploadFilesSchema.safeParse(files);
|
||||
const parseResult = z.array(zUploadFile).safeParse(files);
|
||||
|
||||
if (!parseResult.success) {
|
||||
const description =
|
||||
maxImageUploadCount === undefined
|
||||
? t('toast.uploadFailedInvalidUploadDesc')
|
||||
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
|
||||
const description = t('toast.uploadFailedInvalidUploadDesc');
|
||||
|
||||
toast({
|
||||
id: 'UPLOAD_FAILED',
|
||||
@@ -118,17 +109,23 @@ export const FullscreenDropzone = memo(() => {
|
||||
|
||||
const autoAddBoardId = selectAutoAddBoardId(getState());
|
||||
|
||||
const uploadArgs: UploadImageArg[] = files.map((file, i) => ({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
isFirstUploadOfBatch: i === 0,
|
||||
}));
|
||||
if (isClientSideUploadEnabled) {
|
||||
for (const [i, file] of files.entries()) {
|
||||
await clientSideUpload(file, i);
|
||||
}
|
||||
} else {
|
||||
const uploadArgs: UploadImageArg[] = files.map((file, i) => ({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
isFirstUploadOfBatch: i === 0,
|
||||
}));
|
||||
|
||||
uploadImages(uploadArgs);
|
||||
uploadImages(uploadArgs);
|
||||
}
|
||||
},
|
||||
[activeTab, isImageViewerOpen, maxImageUploadCount, t]
|
||||
[activeTab, isImageViewerOpen, t, isClientSideUploadEnabled, clientSideUpload]
|
||||
);
|
||||
|
||||
const onPaste = useCallback(
|
||||
|
||||
@@ -1,31 +1,18 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
|
||||
import { t } from 'i18next';
|
||||
import { useMemo } from 'react';
|
||||
import { PiUploadBold } from 'react-icons/pi';
|
||||
|
||||
export const GalleryUploadButton = () => {
|
||||
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
|
||||
const uploadOptions = useMemo(() => ({ allowMultiple: maxImageUploadCount !== 1 }), [maxImageUploadCount]);
|
||||
const uploadApi = useImageUploadButton(uploadOptions);
|
||||
const uploadApi = useImageUploadButton({ allowMultiple: true });
|
||||
return (
|
||||
<>
|
||||
<IconButton
|
||||
size="sm"
|
||||
alignSelf="stretch"
|
||||
variant="link"
|
||||
aria-label={
|
||||
maxImageUploadCount === undefined || maxImageUploadCount > 1
|
||||
? t('accessibility.uploadImages')
|
||||
: t('accessibility.uploadImage')
|
||||
}
|
||||
tooltip={
|
||||
maxImageUploadCount === undefined || maxImageUploadCount > 1
|
||||
? t('accessibility.uploadImages')
|
||||
: t('accessibility.uploadImage')
|
||||
}
|
||||
aria-label={t('accessibility.uploadImages')}
|
||||
tooltip={t('accessibility.uploadImages')}
|
||||
icon={<PiUploadBold />}
|
||||
{...uploadApi.getUploadButtonProps()}
|
||||
/>
|
||||
|
||||
@@ -49,7 +49,11 @@ export const useGalleryHotkeys = () => {
|
||||
useRegisteredHotkeys({
|
||||
id: 'galleryNavLeft',
|
||||
category: 'gallery',
|
||||
callback: () => {
|
||||
callback: (e) => {
|
||||
// Skip the hotkey if the user is focused on a tab element - the arrow keys are used to navigate between tabs.
|
||||
if (e.target instanceof HTMLElement && e.target.getAttribute('role') === 'tab') {
|
||||
return;
|
||||
}
|
||||
if (isOnFirstImageOfView && isPrevEnabled && !queryResult.isFetching) {
|
||||
goPrev('arrow');
|
||||
return;
|
||||
@@ -71,7 +75,11 @@ export const useGalleryHotkeys = () => {
|
||||
useRegisteredHotkeys({
|
||||
id: 'galleryNavRight',
|
||||
category: 'gallery',
|
||||
callback: () => {
|
||||
callback: (e) => {
|
||||
// Skip the hotkey if the user is focused on a tab element - the arrow keys are used to navigate between tabs.
|
||||
if (e.target instanceof HTMLElement && e.target.getAttribute('role') === 'tab') {
|
||||
return;
|
||||
}
|
||||
if (isOnLastImageOfView && isNextEnabled && !queryResult.isFetching) {
|
||||
goNext('arrow');
|
||||
return;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');
|
||||
|
||||
@@ -7,3 +8,9 @@ export const imageDownloaded = createAction('gallery/imageDownloaded');
|
||||
export const imageCopiedToClipboard = createAction('gallery/imageCopiedToClipboard');
|
||||
|
||||
export const imageOpenedInNewTab = createAction('gallery/imageOpenedInNewTab');
|
||||
|
||||
export const imageUploadedClientSide = createAction<{
|
||||
imageDTO: ImageDTO;
|
||||
silent: boolean;
|
||||
isFirstUploadOfBatch: boolean;
|
||||
}>('gallery/imageUploadedClientSide');
|
||||
|
||||
@@ -470,31 +470,8 @@ export const nodesSlice = createSlice({
|
||||
builder.addCase(workflowLoaded, (state, action) => {
|
||||
const { nodes, edges } = action.payload;
|
||||
|
||||
const changes: NodeChange<AnyNode>[] = [];
|
||||
for (const node of nodes) {
|
||||
if (node.type === 'notes') {
|
||||
changes.push({
|
||||
type: 'add',
|
||||
item: {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
...node,
|
||||
},
|
||||
});
|
||||
} else if (node.type === 'invocation') {
|
||||
changes.push({
|
||||
type: 'add',
|
||||
item: {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
...node,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
state.nodes = applyNodeChanges<AnyNode>(changes, []);
|
||||
state.edges = applyEdgeChanges(
|
||||
edges.map((edge) => ({ type: 'add', item: edge })),
|
||||
[]
|
||||
);
|
||||
state.nodes = nodes.map((node) => ({ ...SHARED_NODE_PROPERTIES, ...node }));
|
||||
state.edges = edges;
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
@@ -79,4 +79,4 @@ export const isInvocationOutputSchemaObject = (
|
||||
|
||||
export const isInvocationFieldSchema = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject
|
||||
): obj is InvocationFieldSchema => !('$ref' in obj);
|
||||
): obj is InvocationFieldSchema => 'field_kind' in obj;
|
||||
|
||||
@@ -148,7 +148,11 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
|
||||
}
|
||||
}
|
||||
}
|
||||
edges.forEach((edge, i) => {
|
||||
|
||||
// Stash invalid edges here to be deleted later
|
||||
const edgesToDelete = new Set<string>();
|
||||
|
||||
for (const edge of edges) {
|
||||
// Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow.
|
||||
const sourceNode = nodes.find(({ id }) => id === edge.source);
|
||||
const targetNode = nodes.find(({ id }) => id === edge.target);
|
||||
@@ -215,8 +219,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
|
||||
}
|
||||
|
||||
if (issues.length) {
|
||||
// This edge has some issues. Remove it.
|
||||
delete edges[i];
|
||||
edgesToDelete.add(edge.id);
|
||||
const source = edge.type === 'default' ? `${edge.source}.${edge.sourceHandle}` : edge.source;
|
||||
const target = edge.type === 'default' ? `${edge.source}.${edge.targetHandle}` : edge.target;
|
||||
warnings.push({
|
||||
@@ -225,7 +228,10 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
|
||||
data: edge,
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Remove invalid edges
|
||||
_workflow.edges = edges.filter(({ id }) => !edgesToDelete.has(id));
|
||||
|
||||
// Migrated exposed fields to form elements if they exist and the form does not
|
||||
// Note: If the form is invalid per its zod schema, it will be reset to a default, empty form!
|
||||
|
||||
@@ -20,6 +20,7 @@ const initialConfigState: AppConfig = {
|
||||
shouldFetchMetadataFromApi: false,
|
||||
allowPrivateBoards: false,
|
||||
allowPrivateStylePresets: false,
|
||||
allowClientSideUpload: false,
|
||||
disabledTabs: [],
|
||||
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
|
||||
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'],
|
||||
@@ -218,6 +219,5 @@ export const selectWorkflowFetchDebounce = createConfigSelector((config) => conf
|
||||
export const selectMetadataFetchDebounce = createConfigSelector((config) => config.metadataFetchDebounce ?? 300);
|
||||
|
||||
export const selectIsModelsTabDisabled = createConfigSelector((config) => config.disabledTabs.includes('models'));
|
||||
export const selectMaxImageUploadCount = createConfigSelector((config) => config.maxImageUploadCount);
|
||||
|
||||
export const selectIsClientSideUploadEnabled = createConfigSelector((config) => config.allowClientSideUpload);
|
||||
export const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);
|
||||
|
||||
@@ -7,6 +7,8 @@ import type {
|
||||
DeleteBoardResult,
|
||||
GraphAndWorkflowResponse,
|
||||
ImageDTO,
|
||||
ImageUploadEntryRequest,
|
||||
ImageUploadEntryResponse,
|
||||
ListImagesArgs,
|
||||
ListImagesResponse,
|
||||
UploadImageArg,
|
||||
@@ -287,6 +289,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
},
|
||||
};
|
||||
},
|
||||
|
||||
invalidatesTags: (result) => {
|
||||
if (!result || result.is_intermediate) {
|
||||
// Don't add it to anything
|
||||
@@ -314,7 +317,13 @@ export const imagesApi = api.injectEndpoints({
|
||||
];
|
||||
},
|
||||
}),
|
||||
|
||||
createImageUploadEntry: build.mutation<ImageUploadEntryResponse, ImageUploadEntryRequest>({
|
||||
query: ({ width, height, board_id }) => ({
|
||||
url: buildImagesUrl(),
|
||||
method: 'POST',
|
||||
body: { width, height, board_id },
|
||||
}),
|
||||
}),
|
||||
deleteBoard: build.mutation<DeleteBoardResult, string>({
|
||||
query: (board_id) => ({ url: buildBoardsUrl(board_id), method: 'DELETE' }),
|
||||
invalidatesTags: () => [
|
||||
@@ -549,6 +558,7 @@ export const {
|
||||
useGetImageWorkflowQuery,
|
||||
useLazyGetImageWorkflowQuery,
|
||||
useUploadImageMutation,
|
||||
useCreateImageUploadEntryMutation,
|
||||
useClearIntermediatesMutation,
|
||||
useAddImagesToBoardMutation,
|
||||
useRemoveImagesFromBoardMutation,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -354,3 +354,6 @@ export type UploadImageArg = {
|
||||
*/
|
||||
isFirstUploadOfBatch?: boolean;
|
||||
};
|
||||
|
||||
export type ImageUploadEntryResponse = S['ImageUploadEntry'];
|
||||
export type ImageUploadEntryRequest = paths['/api/v1/images/']['post']['requestBody']['content']['application/json'];
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "5.9.0"
|
||||
__version__ = "5.10.0dev1"
|
||||
|
||||
14
pins.json
Normal file
14
pins.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"python": "3.11",
|
||||
"torchIndexUrl": {
|
||||
"win32": {
|
||||
"cuda": "https://download.pytorch.org/whl/cu126"
|
||||
},
|
||||
"linux": {
|
||||
"cpu": "https://download.pytorch.org/whl/cpu",
|
||||
"rocm": "https://download.pytorch.org/whl/rocm6.2.4",
|
||||
"cuda": "https://download.pytorch.org/whl/cu126"
|
||||
},
|
||||
"darwin": {}
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
||||
[project]
|
||||
name = "InvokeAI"
|
||||
description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process"
|
||||
requires-python = ">=3.10, <3.12"
|
||||
requires-python = ">=3.10, <3.13"
|
||||
readme = { content-type = "text/markdown", file = "README.md" }
|
||||
keywords = ["stable-diffusion", "AI"]
|
||||
dynamic = ["version"]
|
||||
@@ -33,12 +33,12 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
# Core generation dependencies, pinned for reproducible builds.
|
||||
"accelerate==1.0.1",
|
||||
"bitsandbytes==0.45.0; sys_platform!='darwin'",
|
||||
"accelerate",
|
||||
"bitsandbytes; sys_platform!='darwin'",
|
||||
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel==2.0.2",
|
||||
"controlnet-aux==0.0.7",
|
||||
"diffusers[torch]==0.31.0",
|
||||
"diffusers[torch]",
|
||||
"gguf==0.10.0",
|
||||
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||
"mediapipe==0.10.14", # needed for "mediapipeface" controlnet model
|
||||
@@ -46,26 +46,27 @@ dependencies = [
|
||||
"onnx==1.16.1",
|
||||
"onnxruntime==1.19.2",
|
||||
"opencv-python==4.9.0.80",
|
||||
"pytorch-lightning==2.1.3",
|
||||
"safetensors==0.4.3",
|
||||
"pytorch-lightning",
|
||||
"safetensors",
|
||||
# sentencepiece is required to load T5TokenizerFast (used by FLUX).
|
||||
"sentencepiece==0.2.0",
|
||||
"spandrel==0.3.4",
|
||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||
"torch<2.5.0", # torch and related dependencies are loosely pinned, will respect requirement of `diffusers[torch]`
|
||||
"timm~=1.0.0",
|
||||
"torch~=2.6.0", # torch and related dependencies are loosely pinned, will respect requirement of `diffusers[torch]`
|
||||
"torchmetrics",
|
||||
"torchsde",
|
||||
"torchvision",
|
||||
"transformers==4.46.3",
|
||||
"timm~=1.0.0", # controlnet-aux depends on a version of timm that breaks LLaVA. explicitly pinning `timm` here, which results in a downgrade of `controlnet-aux`. https://github.com/huggingface/controlnet_aux/pull/101. If this poses a problem, we need to decide between newer `controlnet_aux` and LLaVA, OR we can fork `controlnet_aux` and update the pin.
|
||||
"transformers",
|
||||
|
||||
# Core application dependencies, pinned for reproducible builds.
|
||||
"fastapi-events==0.11.1",
|
||||
"fastapi==0.111.0",
|
||||
"huggingface-hub==0.26.1",
|
||||
"pydantic-settings==2.2.1",
|
||||
"pydantic==2.7.2",
|
||||
"python-socketio==5.11.1",
|
||||
"uvicorn[standard]==0.28.0",
|
||||
"fastapi-events",
|
||||
"fastapi",
|
||||
"huggingface-hub",
|
||||
"pydantic-settings",
|
||||
"pydantic",
|
||||
"python-socketio",
|
||||
"uvicorn[standard]",
|
||||
|
||||
# Auxiliary dependencies, pinned only if necessary.
|
||||
"albumentations",
|
||||
@@ -90,11 +91,8 @@ dependencies = [
|
||||
"pyreadline3",
|
||||
"python-multipart",
|
||||
"requests",
|
||||
"rich~=13.3",
|
||||
"scikit-image",
|
||||
"semver~=3.0.1",
|
||||
"test-tube",
|
||||
"windows-curses; sys_platform=='win32'",
|
||||
"humanize==4.12.1",
|
||||
]
|
||||
|
||||
@@ -117,7 +115,7 @@ dependencies = [
|
||||
]
|
||||
"dev" = ["jurigged", "pudb", "snakeviz", "gprof2dot"]
|
||||
"test" = [
|
||||
"ruff~=0.9.9",
|
||||
"ruff~=0.11.2",
|
||||
"ruff-lsp~=0.0.62",
|
||||
"mypy",
|
||||
"pre-commit",
|
||||
|
||||
@@ -22,9 +22,8 @@ from pathlib import Path
|
||||
import humanize
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.config import ModelOnDisk
|
||||
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.taxonomy import ModelFormat
|
||||
|
||||
|
||||
def strip(v):
|
||||
@@ -63,7 +62,7 @@ def load_stripped_model(path: Path, *args, **kwargs):
|
||||
|
||||
def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk:
|
||||
original = ModelOnDisk(original_model_path)
|
||||
if original.format_type == ModelFormat.Checkpoint:
|
||||
if original.path.is_file():
|
||||
shutil.copy2(original.path, stripped_model_path)
|
||||
else:
|
||||
shutil.copytree(original.path, stripped_model_path, dirs_exist_ok=True)
|
||||
@@ -71,7 +70,7 @@ def create_stripped_model(original_model_path: Path, stripped_model_path: Path)
|
||||
print(f"Created clone of {original.name} at {stripped.path}")
|
||||
|
||||
for component_path in stripped.component_paths():
|
||||
original_state_dict = ModelOnDisk.load_state_dict(component_path)
|
||||
original_state_dict = stripped.load_state_dict(component_path)
|
||||
stripped_state_dict = strip(original_state_dict) # type: ignore
|
||||
with open(component_path, "w") as f:
|
||||
json.dump(stripped_state_dict, f, indent=4)
|
||||
|
||||
@@ -24,7 +24,7 @@ from tests.backend.flux.controlnet.xlabs_flux_controlnet_state_dict import xlabs
|
||||
],
|
||||
)
|
||||
def test_is_state_dict_xlabs_controlnet(sd_shapes: dict[str, list[int]], expected: bool):
|
||||
sd = {k: None for k in sd_shapes}
|
||||
sd = dict.fromkeys(sd_shapes)
|
||||
assert is_state_dict_xlabs_controlnet(sd) == expected
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ def test_is_state_dict_xlabs_controlnet(sd_shapes: dict[str, list[int]], expecte
|
||||
],
|
||||
)
|
||||
def test_is_state_dict_instantx_controlnet(sd_keys: list[str], expected: bool):
|
||||
sd = {k: None for k in sd_keys}
|
||||
sd = dict.fromkeys(sd_keys)
|
||||
assert is_state_dict_instantx_controlnet(sd) == expected
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_v2_state_dict import xl
|
||||
@pytest.mark.parametrize("sd_shapes", [xlabs_flux_ip_adapter_sd_shapes, xlabs_flux_ip_adapter_v2_sd_shapes])
|
||||
def test_is_state_dict_xlabs_ip_adapter(sd_shapes: dict[str, list[int]]):
|
||||
# Construct a dummy state_dict.
|
||||
sd = {k: None for k in sd_shapes}
|
||||
sd = dict.fromkeys(sd_shapes)
|
||||
|
||||
assert is_state_dict_xlabs_ip_adapter(sd)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.baseinvocation import InvocationRegistry
|
||||
from invokeai.app.services.config.config_default import (
|
||||
DefaultInvokeAIAppConfig,
|
||||
InvokeAIAppConfig,
|
||||
@@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir):
|
||||
|
||||
# We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for
|
||||
# subsequent graph validations
|
||||
BaseInvocation.invalidate_typeadapter()
|
||||
InvocationRegistry.invalidate_invocation_typeadapter()
|
||||
|
||||
# confirm graph validation fails when using denied node
|
||||
Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}})
|
||||
@@ -284,7 +284,7 @@ def test_deny_nodes(patch_rootdir):
|
||||
Graph.model_validate({"nodes": {"1": {"id": "1", "type": "float"}}})
|
||||
|
||||
# confirm invocations union will not have denied nodes
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
all_invocations = InvocationRegistry.get_invocation_classes()
|
||||
|
||||
has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1
|
||||
has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1
|
||||
@@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir):
|
||||
|
||||
# Reset the config so that it doesn't affect other tests
|
||||
get_config.cache_clear()
|
||||
BaseInvocation.invalidate_typeadapter()
|
||||
InvocationRegistry.invalidate_invocation_typeadapter()
|
||||
|
||||
@@ -17,7 +17,6 @@ from invokeai.backend.model_manager.config import (
|
||||
MainDiffusersConfig,
|
||||
ModelConfigBase,
|
||||
ModelConfigFactory,
|
||||
ModelOnDisk,
|
||||
get_model_discriminator_value,
|
||||
)
|
||||
from invokeai.backend.model_manager.legacy_probe import (
|
||||
@@ -27,6 +26,7 @@ from invokeai.backend.model_manager.legacy_probe import (
|
||||
get_default_settings_control_adapters,
|
||||
get_default_settings_main,
|
||||
)
|
||||
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5acefb3658338a4126736e2da02cfef5a9ce6e2469564a6c7994ae34e8ef2e8a
|
||||
size 192447
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:55aafd0f9b4ac2863361573b070320e13b800b2359a81a73878008bdffc3edfa
|
||||
size 201040
|
||||
@@ -21,16 +21,18 @@ def count_files(path: Path):
|
||||
|
||||
@pytest.fixture
|
||||
def obj_serializer(tmp_path: Path):
|
||||
return ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||
return ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fwd_cache(tmp_path: Path):
|
||||
return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2)
|
||||
return ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass]), max_cache_size=2
|
||||
)
|
||||
|
||||
|
||||
def test_obj_serializer_disk_initializes(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass])
|
||||
assert obj_serializer._output_dir == tmp_path
|
||||
|
||||
|
||||
@@ -70,7 +72,7 @@ def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDa
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
|
||||
assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory)
|
||||
assert obj_serializer._base_output_dir == tmp_path
|
||||
assert obj_serializer._output_dir != tmp_path
|
||||
@@ -78,21 +80,21 @@ def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
|
||||
tempdir_path = obj_serializer._output_dir
|
||||
del obj_serializer
|
||||
assert not tempdir_path.exists()
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
|
||||
tempdir_path = obj_serializer._output_dir
|
||||
obj_serializer.stop(None) # pyright: ignore [reportArgumentType]
|
||||
assert not tempdir_path.exists()
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = obj_serializer.save(obj_1)
|
||||
assert Path(obj_serializer._output_dir, obj_1_name).exists()
|
||||
@@ -102,19 +104,19 @@ def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
|
||||
def test_obj_serializer_ephemeral_deletes_dangling_tempdirs_on_init(tmp_path: Path):
|
||||
tempdir = tmp_path / "tmpdir"
|
||||
tempdir.mkdir()
|
||||
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
|
||||
assert not tempdir.exists()
|
||||
|
||||
|
||||
def test_obj_serializer_does_not_delete_tempdirs_on_init(tmp_path: Path):
|
||||
tempdir = tmp_path / "tmpdir"
|
||||
tempdir.mkdir()
|
||||
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=False)
|
||||
ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=False)
|
||||
assert tempdir.exists()
|
||||
|
||||
|
||||
def test_obj_serializer_disk_different_types(tmp_path: Path):
|
||||
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass])
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = obj_serializer_1.save(obj_1)
|
||||
obj_1_loaded = obj_serializer_1.load(obj_1_name)
|
||||
@@ -123,19 +125,19 @@ def test_obj_serializer_disk_different_types(tmp_path: Path):
|
||||
assert obj_1_loaded.foo == "bar"
|
||||
assert obj_1_name.startswith("MockDataclass_")
|
||||
|
||||
obj_serializer_2 = ObjectSerializerDisk[int](tmp_path)
|
||||
obj_serializer_2 = ObjectSerializerDisk[int](tmp_path, safe_globals=[int])
|
||||
obj_2_name = obj_serializer_2.save(9001)
|
||||
assert obj_serializer_2._obj_class_name == "int"
|
||||
assert obj_serializer_2.load(obj_2_name) == 9001
|
||||
assert obj_2_name.startswith("int_")
|
||||
|
||||
obj_serializer_3 = ObjectSerializerDisk[str](tmp_path)
|
||||
obj_serializer_3 = ObjectSerializerDisk[str](tmp_path, safe_globals=[str])
|
||||
obj_3_name = obj_serializer_3.save("foo")
|
||||
assert obj_serializer_3._obj_class_name == "str"
|
||||
assert obj_serializer_3.load(obj_3_name) == "foo"
|
||||
assert obj_3_name.startswith("str_")
|
||||
|
||||
obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path)
|
||||
obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path, safe_globals=[torch.Tensor])
|
||||
obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3]))
|
||||
obj_4_loaded = obj_serializer_4.load(obj_4_name)
|
||||
assert obj_serializer_4._obj_class_name == "Tensor"
|
||||
|
||||
Reference in New Issue
Block a user