mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 15:27:55 -05:00
Compare commits
1 Commits
main
...
ryan/conse
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
109cbb8532 |
@@ -1,11 +1,9 @@
|
||||
*
|
||||
!invokeai
|
||||
!pyproject.toml
|
||||
!uv.lock
|
||||
!docker/docker-entrypoint.sh
|
||||
!LICENSE
|
||||
|
||||
**/dist
|
||||
**/node_modules
|
||||
**/__pycache__
|
||||
**/*.egg-info
|
||||
**/*.egg-info
|
||||
@@ -1,5 +1,2 @@
|
||||
b3dccfaeb636599c02effc377cdd8a87d658256c
|
||||
218b6d0546b990fc449c876fb99f44b50c4daa35
|
||||
182580ff6970caed400be178c5b888514b75d7f2
|
||||
8e9d5c1187b0d36da80571ce4c8ba9b3a37b6c46
|
||||
99aac5870e1092b182e6c5f21abcaab6936a4ad1
|
||||
3
.gitattributes
vendored
3
.gitattributes
vendored
@@ -2,5 +2,4 @@
|
||||
# Only affects text files and ignores other file types.
|
||||
# For more info see: https://www.aleksandrhovhannisyan.com/blog/crlf-vs-lf-normalizing-line-endings-in-git/
|
||||
* text=auto
|
||||
docker/** text eol=lf
|
||||
tests/test_model_probe/stripped_models/** filter=lfs diff=lfs merge=lfs -text
|
||||
docker/** text eol=lf
|
||||
33
.github/CODEOWNERS
vendored
33
.github/CODEOWNERS
vendored
@@ -1,31 +1,32 @@
|
||||
# continuous integration
|
||||
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku @psychedelicious
|
||||
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @blessedcoolant @hipsterusername @psychedelicious
|
||||
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @psychedelicious
|
||||
/docs/ @lstein @blessedcoolant @hipsterusername @Millu
|
||||
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @Millu
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @blessedcoolant @psychedelicious @hipsterusername @jazzhaiku
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising @hipsterusername
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @lstein @blessedcoolant @psychedelicious @hipsterusername
|
||||
/docker/ @lstein @blessedcoolant @psychedelicious @hipsterusername @ebr
|
||||
/scripts/ @ebr @lstein @psychedelicious @hipsterusername
|
||||
/installer/ @lstein @ebr @psychedelicious @hipsterusername
|
||||
/invokeai/assets @lstein @ebr @psychedelicious @hipsterusername
|
||||
/invokeai/configs @lstein @psychedelicious @hipsterusername
|
||||
/invokeai/version @lstein @blessedcoolant @psychedelicious @hipsterusername
|
||||
/pyproject.toml @lstein @blessedcoolant @hipsterusername
|
||||
/docker/ @lstein @blessedcoolant @hipsterusername @ebr
|
||||
/scripts/ @ebr @lstein @hipsterusername
|
||||
/installer/ @lstein @ebr @hipsterusername
|
||||
/invokeai/assets @lstein @ebr @hipsterusername
|
||||
/invokeai/configs @lstein @hipsterusername
|
||||
/invokeai/version @lstein @blessedcoolant @hipsterusername
|
||||
|
||||
# web ui
|
||||
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @lstein @blessedcoolant @hipsterusername @jazzhaiku @psychedelicious @maryhipp
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick @hipsterusername
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein @psychedelicious @hipsterusername
|
||||
/invokeai/frontend/install @lstein @ebr @psychedelicious @hipsterusername
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant @psychedelicious @hipsterusername
|
||||
/invokeai/frontend/training @lstein @blessedcoolant @psychedelicious @hipsterusername
|
||||
/invokeai/frontend/CLI @lstein @hipsterusername
|
||||
/invokeai/frontend/install @lstein @ebr @hipsterusername
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp @hipsterusername
|
||||
|
||||
26
.github/ISSUE_TEMPLATE/BUG_REPORT.yml
vendored
26
.github/ISSUE_TEMPLATE/BUG_REPORT.yml
vendored
@@ -21,20 +21,6 @@ body:
|
||||
- label: I have searched the existing issues
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: install_method
|
||||
attributes:
|
||||
label: Install method
|
||||
description: How did you install Invoke?
|
||||
multiple: false
|
||||
options:
|
||||
- "Invoke's Launcher"
|
||||
- 'Stability Matrix'
|
||||
- 'Pinokio'
|
||||
- 'Manual'
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: __Describe your environment__
|
||||
@@ -90,8 +76,8 @@ body:
|
||||
attributes:
|
||||
label: Version number
|
||||
description: |
|
||||
The version of Invoke you have installed. If it is not the [latest version](https://github.com/invoke-ai/InvokeAI/releases/latest), please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
|
||||
placeholder: ex. v6.0.2
|
||||
The version of Invoke you have installed. If it is not the latest version, please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
|
||||
placeholder: ex. 3.6.1
|
||||
validations:
|
||||
required: true
|
||||
|
||||
@@ -99,17 +85,17 @@ body:
|
||||
id: browser-version
|
||||
attributes:
|
||||
label: Browser
|
||||
description: Your web browser and version, if you do not use the Launcher's provided GUI.
|
||||
description: Your web browser and version.
|
||||
placeholder: ex. Firefox 123.0b3
|
||||
validations:
|
||||
required: false
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: python-deps
|
||||
attributes:
|
||||
label: System Information
|
||||
label: Python dependencies
|
||||
description: |
|
||||
Click the gear icon at the bottom left corner, then click "About". Click the copy button and then paste here.
|
||||
If the problem occurred during image generation, click the gear icon at the bottom left corner, click "About", click the copy button and then paste here.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
|
||||
@@ -3,15 +3,15 @@ description: Installs frontend dependencies with pnpm, with caching
|
||||
runs:
|
||||
using: 'composite'
|
||||
steps:
|
||||
- name: setup node 20
|
||||
- name: setup node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
node-version: '18'
|
||||
|
||||
- name: setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 10
|
||||
version: 8.15.6
|
||||
run_install: false
|
||||
|
||||
- name: get pnpm store directory
|
||||
|
||||
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -18,6 +18,5 @@
|
||||
|
||||
- [ ] _The PR has a short but descriptive title, suitable for a changelog_
|
||||
- [ ] _Tests added / updated (if applicable)_
|
||||
- [ ] _❗Changes to a redux slice have a corresponding migration_
|
||||
- [ ] _Documentation added / updated (if applicable)_
|
||||
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
|
||||
|
||||
21
.github/workflows/build-container.yml
vendored
21
.github/workflows/build-container.yml
vendored
@@ -45,9 +45,6 @@ jobs:
|
||||
steps:
|
||||
- name: Free up more disk space on the runner
|
||||
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
|
||||
# the /mnt dir has 70GBs of free space
|
||||
# /dev/sda1 74G 28K 70G 1% /mnt
|
||||
# According to some online posts the /mnt is not always there, so checking before setting docker to use it
|
||||
run: |
|
||||
echo "----- Free space before cleanup"
|
||||
df -h
|
||||
@@ -55,11 +52,6 @@ jobs:
|
||||
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||
sudo swapoff /mnt/swapfile
|
||||
sudo rm -rf /mnt/swapfile
|
||||
if [ -d /mnt ]; then
|
||||
sudo chmod -R 777 /mnt
|
||||
echo '{"data-root": "/mnt/docker-root"}' | sudo tee /etc/docker/daemon.json
|
||||
sudo systemctl restart docker
|
||||
fi
|
||||
echo "----- Free space after cleanup"
|
||||
df -h
|
||||
|
||||
@@ -84,6 +76,9 @@ jobs:
|
||||
latest=${{ matrix.gpu-driver == 'cuda' && github.ref == 'refs/heads/main' }}
|
||||
suffix=-${{ matrix.gpu-driver }},onlatest=false
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
@@ -105,12 +100,10 @@ jobs:
|
||||
context: .
|
||||
file: docker/Dockerfile
|
||||
platforms: ${{ env.PLATFORMS }}
|
||||
build-args: |
|
||||
GPU_DRIVER=${{ matrix.gpu-driver }}
|
||||
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
# cache-from: |
|
||||
# type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
# type=gha,scope=main-${{ matrix.gpu-driver }}
|
||||
# cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
cache-from: |
|
||||
type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
type=gha,scope=main-${{ matrix.gpu-driver }}
|
||||
cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Builds and uploads python build artifacts.
|
||||
# Builds and uploads the installer and python build artifacts.
|
||||
|
||||
name: build wheel
|
||||
name: build installer
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: setup python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
@@ -27,12 +27,19 @@ jobs:
|
||||
- name: setup frontend
|
||||
uses: ./.github/actions/install-frontend-deps
|
||||
|
||||
- name: build wheel
|
||||
id: build_wheel
|
||||
run: ./scripts/build_wheel.sh
|
||||
- name: create installer
|
||||
id: create_installer
|
||||
run: ./create_installer.sh
|
||||
working-directory: installer
|
||||
|
||||
- name: upload python distribution artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: ${{ steps.build_wheel.outputs.DIST_PATH }}
|
||||
path: ${{ steps.create_installer.outputs.DIST_PATH }}
|
||||
|
||||
- name: upload installer artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: installer
|
||||
path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}
|
||||
7
.github/workflows/frontend-checks.yml
vendored
7
.github/workflows/frontend-checks.yml
vendored
@@ -44,12 +44,7 @@ jobs:
|
||||
- name: check for changed frontend files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
|
||||
# See:
|
||||
# - CVE-2025-30066
|
||||
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
|
||||
# - https://github.com/tj-actions/changed-files/issues/2463
|
||||
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
frontend:
|
||||
|
||||
7
.github/workflows/frontend-tests.yml
vendored
7
.github/workflows/frontend-tests.yml
vendored
@@ -44,12 +44,7 @@ jobs:
|
||||
- name: check for changed frontend files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
|
||||
# See:
|
||||
# - CVE-2025-30066
|
||||
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
|
||||
# - https://github.com/tj-actions/changed-files/issues/2463
|
||||
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
frontend:
|
||||
|
||||
30
.github/workflows/lfs-checks.yml
vendored
30
.github/workflows/lfs-checks.yml
vendored
@@ -1,30 +0,0 @@
|
||||
# Checks that large files and LFS-tracked files are properly checked in with pointer format.
|
||||
# Uses https://github.com/ppremk/lfs-warning to detect LFS issues.
|
||||
|
||||
name: 'lfs checks'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
lfs-check:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
permissions:
|
||||
# Required to label and comment on the PRs
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check lfs files
|
||||
uses: ppremk/lfs-warning@v3.3
|
||||
28
.github/workflows/python-checks.yml
vendored
28
.github/workflows/python-checks.yml
vendored
@@ -34,9 +34,6 @@ on:
|
||||
|
||||
jobs:
|
||||
python-checks:
|
||||
env:
|
||||
# uv requires a venv by default - but for this, we can simply use the system python
|
||||
UV_SYSTEM_PYTHON: 1
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
steps:
|
||||
@@ -46,12 +43,7 @@ jobs:
|
||||
- name: check for changed python files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
|
||||
# See:
|
||||
# - CVE-2025-30066
|
||||
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
|
||||
# - https://github.com/tj-actions/changed-files/issues/2463
|
||||
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
@@ -60,23 +52,25 @@ jobs:
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup uv
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: astral-sh/setup-uv@v5
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: check pypi classifiers
|
||||
- name: install ruff
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: uv run --no-project scripts/check_classifiers.py ./pyproject.toml
|
||||
run: pip install ruff==0.6.0
|
||||
shell: bash
|
||||
|
||||
- name: ruff check
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: uv tool run ruff@0.11.2 check --output-format=github .
|
||||
run: ruff check --output-format=github .
|
||||
shell: bash
|
||||
|
||||
- name: ruff format
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: uv tool run ruff@0.11.2 format --check .
|
||||
run: ruff format --check .
|
||||
shell: bash
|
||||
|
||||
40
.github/workflows/python-tests.yml
vendored
40
.github/workflows/python-tests.yml
vendored
@@ -39,15 +39,24 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- '3.10'
|
||||
- '3.11'
|
||||
- '3.12'
|
||||
platform:
|
||||
- linux-cuda-11_7
|
||||
- linux-rocm-5_2
|
||||
- linux-cpu
|
||||
- macos-default
|
||||
- windows-cpu
|
||||
include:
|
||||
- platform: linux-cuda-11_7
|
||||
os: ubuntu-22.04
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-rocm-5_2
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-cpu
|
||||
os: ubuntu-24.04
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: macos-default
|
||||
@@ -61,22 +70,14 @@ jobs:
|
||||
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
|
||||
env:
|
||||
PIP_USE_PEP517: '1'
|
||||
UV_SYSTEM_PYTHON: 1
|
||||
|
||||
steps:
|
||||
- name: checkout
|
||||
# https://github.com/nschloe/action-cached-lfs-checkout
|
||||
uses: nschloe/action-cached-lfs-checkout@f46300cd8952454b9f0a21a3d133d4bd5684cfc2
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed python files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
|
||||
# See:
|
||||
# - CVE-2025-30066
|
||||
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
|
||||
# - https://github.com/tj-actions/changed-files/issues/2463
|
||||
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
@@ -85,25 +86,20 @@ jobs:
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup uv
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
env:
|
||||
UV_INDEX: ${{ matrix.extra-index-url }}
|
||||
run: uv pip install --editable ".[test]"
|
||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||
run: >
|
||||
pip3 install --editable=".[test]"
|
||||
|
||||
- name: run pytest
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
always_run: true
|
||||
|
||||
build:
|
||||
uses: ./.github/workflows/build-wheel.yml
|
||||
uses: ./.github/workflows/build-installer.yml
|
||||
|
||||
publish-testpypi:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
39
.github/workflows/typegen-checks.yml
vendored
39
.github/workflows/typegen-checks.yml
vendored
@@ -39,52 +39,27 @@ jobs:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up more disk space on the runner
|
||||
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
|
||||
run: |
|
||||
echo "----- Free space before cleanup"
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||
sudo swapoff /mnt/swapfile
|
||||
sudo rm -rf /mnt/swapfile
|
||||
echo "----- Free space after cleanup"
|
||||
df -h
|
||||
|
||||
- name: check for changed files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
|
||||
# See:
|
||||
# - CVE-2025-30066
|
||||
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
|
||||
# - https://github.com/tj-actions/changed-files/issues/2463
|
||||
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
src:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
|
||||
- name: setup uv
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
python-version: '3.11'
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install dependencies
|
||||
- name: install python dependencies
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
env:
|
||||
UV_INDEX: ${{ matrix.extra-index-url }}
|
||||
run: uv pip install --editable .
|
||||
run: pip3 install --use-pep517 --editable="."
|
||||
|
||||
- name: install frontend dependencies
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
@@ -97,7 +72,7 @@ jobs:
|
||||
|
||||
- name: generate schema
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: cd invokeai/frontend/web && uv run ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
||||
run: make frontend-typegen
|
||||
shell: bash
|
||||
|
||||
- name: compare files
|
||||
|
||||
68
.github/workflows/uv-lock-checks.yml
vendored
68
.github/workflows/uv-lock-checks.yml
vendored
@@ -1,68 +0,0 @@
|
||||
# Check the `uv` lockfile for consistency with `pyproject.toml`.
|
||||
#
|
||||
# If this check fails, you should run `uv lock` to update the lockfile.
|
||||
|
||||
name: 'uv lock checks'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
always_run:
|
||||
description: 'Always run the checks'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
workflow_call:
|
||||
inputs:
|
||||
always_run:
|
||||
description: 'Always run the checks'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
|
||||
jobs:
|
||||
uv-lock-checks:
|
||||
env:
|
||||
# uv requires a venv by default - but for this, we can simply use the system python
|
||||
UV_SYSTEM_PYTHON: 1
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed python files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
|
||||
# See:
|
||||
# - CVE-2025-30066
|
||||
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
|
||||
# - https://github.com/tj-actions/changed-files/issues/2463
|
||||
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
|
||||
with:
|
||||
files_yaml: |
|
||||
uvlock-pyprojecttoml:
|
||||
- 'pyproject.toml'
|
||||
- 'uv.lock'
|
||||
|
||||
- name: setup uv
|
||||
if: ${{ steps.changed-files.outputs.uvlock-pyprojecttoml_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
|
||||
- name: check lockfile
|
||||
if: ${{ steps.changed-files.outputs.uvlock-pyprojecttoml_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: uv lock --locked # this will exit with 1 if the lockfile is not consistent with pyproject.toml
|
||||
shell: bash
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -180,7 +180,6 @@ cython_debug/
|
||||
# Scratch folder
|
||||
.scratch/
|
||||
.vscode/
|
||||
.zed/
|
||||
|
||||
# source installer files
|
||||
installer/*zip
|
||||
@@ -189,6 +188,3 @@ installer/install.sh
|
||||
installer/update.bat
|
||||
installer/update.sh
|
||||
installer/InvokeAI-Installer/
|
||||
.aider*
|
||||
|
||||
.claude/
|
||||
|
||||
@@ -4,29 +4,21 @@ repos:
|
||||
hooks:
|
||||
- id: black
|
||||
name: black
|
||||
stages: [pre-commit]
|
||||
stages: [commit]
|
||||
language: system
|
||||
entry: black
|
||||
types: [python]
|
||||
|
||||
- id: flake8
|
||||
name: flake8
|
||||
stages: [pre-commit]
|
||||
stages: [commit]
|
||||
language: system
|
||||
entry: flake8
|
||||
types: [python]
|
||||
|
||||
- id: isort
|
||||
name: isort
|
||||
stages: [pre-commit]
|
||||
stages: [commit]
|
||||
language: system
|
||||
entry: isort
|
||||
types: [python]
|
||||
|
||||
- id: uvlock
|
||||
name: uv lock
|
||||
stages: [pre-commit]
|
||||
language: system
|
||||
entry: uv lock
|
||||
files: ^pyproject\.toml$
|
||||
pass_filenames: false
|
||||
types: [python]
|
||||
10
Makefile
10
Makefile
@@ -16,7 +16,7 @@ help:
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
||||
@echo "wheel Build the wheel for the current version"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
|
||||
@echo "docs Serve the mkdocs site with live reload"
|
||||
@@ -64,13 +64,13 @@ frontend-dev:
|
||||
frontend-typegen:
|
||||
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
||||
|
||||
# Tag the release
|
||||
wheel:
|
||||
cd scripts && ./build_wheel.sh
|
||||
# Installer zip file
|
||||
installer-zip:
|
||||
cd installer && ./create_installer.sh
|
||||
|
||||
# Tag the release
|
||||
tag-release:
|
||||
cd scripts && ./tag_release.sh
|
||||
cd installer && ./tag_release.sh
|
||||
|
||||
# Generate the OpenAPI Schema for the app
|
||||
openapi:
|
||||
|
||||
@@ -22,10 +22,6 @@
|
||||
## GPU_DRIVER can be set to either `cuda` or `rocm` to enable GPU support in the container accordingly.
|
||||
# GPU_DRIVER=cuda #| rocm
|
||||
|
||||
## If you are using ROCM, you will need to ensure that the render group within the container and the host system use the same group ID.
|
||||
## To obtain the group ID of the render group on the host system, run `getent group render` and grab the number.
|
||||
# RENDER_GROUP_ID=
|
||||
|
||||
## CONTAINER_UID can be set to the UID of the user on the host system that should own the files in the container.
|
||||
## It is usually not necessary to change this. Use `id -u` on the host system to find the UID.
|
||||
# CONTAINER_UID=1000
|
||||
|
||||
@@ -1,11 +1,68 @@
|
||||
# syntax=docker/dockerfile:1.4
|
||||
|
||||
#### Web UI ------------------------------------
|
||||
## Builder stage
|
||||
|
||||
FROM docker.io/node:22-slim AS web-builder
|
||||
FROM library/ubuntu:24.04 AS builder
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt update && apt-get install -y \
|
||||
build-essential \
|
||||
git
|
||||
|
||||
# Install `uv` for package management
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
|
||||
|
||||
ENV VIRTUAL_ENV=/opt/venv
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV PYTHON_VERSION=3.11
|
||||
ENV UV_COMPILE_BYTECODE=1
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
ARG GPU_DRIVER=cuda
|
||||
ARG TARGETPLATFORM="linux/amd64"
|
||||
# unused but available
|
||||
ARG BUILDPLATFORM
|
||||
|
||||
# Switch to the `ubuntu` user to work around dependency issues with uv-installed python
|
||||
RUN mkdir -p ${VIRTUAL_ENV} && \
|
||||
mkdir -p ${INVOKEAI_SRC} && \
|
||||
chmod -R a+w /opt
|
||||
USER ubuntu
|
||||
|
||||
# Install python and create the venv
|
||||
RUN uv python install ${PYTHON_VERSION} && \
|
||||
uv venv --relocatable --prompt "invoke" --python ${PYTHON_VERSION} ${VIRTUAL_ENV}
|
||||
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
COPY invokeai ./invokeai
|
||||
COPY pyproject.toml ./
|
||||
|
||||
# Editable mode helps use the same image for development:
|
||||
# the local working copy can be bind-mounted into the image
|
||||
# at path defined by ${INVOKEAI_SRC}
|
||||
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
|
||||
# x86_64/CUDA is the default
|
||||
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
|
||||
elif [ "$GPU_DRIVER" = "rocm" ]; then \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm6.1"; \
|
||||
else \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu124"; \
|
||||
fi && \
|
||||
uv pip install --python ${PYTHON_VERSION} $extra_index_url_arg -e "."
|
||||
|
||||
#### Build the Web UI ------------------------------------
|
||||
|
||||
FROM node:20-slim AS web-builder
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
ENV PATH="$PNPM_HOME:$PATH"
|
||||
RUN corepack use pnpm@10.x && corepack enable
|
||||
RUN corepack use pnpm@8.x
|
||||
RUN corepack enable
|
||||
|
||||
WORKDIR /build
|
||||
COPY invokeai/frontend/web/ ./
|
||||
@@ -13,73 +70,61 @@ RUN --mount=type=cache,target=/pnpm/store \
|
||||
pnpm install --frozen-lockfile
|
||||
RUN npx vite build
|
||||
|
||||
## Backend ---------------------------------------
|
||||
#### Runtime stage ---------------------------------------
|
||||
|
||||
FROM library/ubuntu:24.04
|
||||
FROM library/ubuntu:24.04 AS runtime
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
--mount=type=cache,target=/var/lib/apt \
|
||||
apt update && apt install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
git \
|
||||
gosu \
|
||||
libglib2.0-0 \
|
||||
libgl1 \
|
||||
libglx-mesa0 \
|
||||
build-essential \
|
||||
libopencv-dev \
|
||||
libstdc++-10-dev
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
|
||||
ENV \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
VIRTUAL_ENV=/opt/venv \
|
||||
INVOKEAI_SRC=/opt/invokeai \
|
||||
PYTHON_VERSION=3.12 \
|
||||
UV_PYTHON=3.12 \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_MANAGED_PYTHON=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||
INVOKEAI_ROOT=/invokeai \
|
||||
INVOKEAI_HOST=0.0.0.0 \
|
||||
INVOKEAI_PORT=9090 \
|
||||
PATH="/opt/venv/bin:$PATH" \
|
||||
CONTAINER_UID=${CONTAINER_UID:-1000} \
|
||||
CONTAINER_GID=${CONTAINER_GID:-1000}
|
||||
RUN apt update && apt install -y --no-install-recommends \
|
||||
git \
|
||||
curl \
|
||||
vim \
|
||||
tmux \
|
||||
ncdu \
|
||||
iotop \
|
||||
bzip2 \
|
||||
gosu \
|
||||
magic-wormhole \
|
||||
libglib2.0-0 \
|
||||
libgl1 \
|
||||
libglx-mesa0 \
|
||||
build-essential \
|
||||
libopencv-dev \
|
||||
libstdc++-10-dev &&\
|
||||
apt-get clean && apt-get autoclean
|
||||
|
||||
ARG GPU_DRIVER=cuda
|
||||
ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV VIRTUAL_ENV=/opt/venv
|
||||
ENV PYTHON_VERSION=3.11
|
||||
ENV INVOKEAI_ROOT=/invokeai
|
||||
ENV INVOKEAI_HOST=0.0.0.0
|
||||
ENV INVOKEAI_PORT=9090
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
|
||||
ENV CONTAINER_UID=${CONTAINER_UID:-1000}
|
||||
ENV CONTAINER_GID=${CONTAINER_GID:-1000}
|
||||
|
||||
# Install `uv` for package management
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.6.9 /uv /uvx /bin/
|
||||
# and install python for the ubuntu user (expected to exist on ubuntu >=24.x)
|
||||
# this is too tiny to optimize with multi-stage builds, but maybe we'll come back to it
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
|
||||
USER ubuntu
|
||||
RUN uv python install ${PYTHON_VERSION}
|
||||
USER root
|
||||
|
||||
# Install python & allow non-root user to use it by traversing the /root dir without read permissions
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv python install ${PYTHON_VERSION} && \
|
||||
# chmod --recursive a+rX /root/.local/share/uv/python
|
||||
chmod 711 /root
|
||||
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
|
||||
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
|
||||
# bind-mount instead of copy to defer adding sources to the image until next layer.
|
||||
#
|
||||
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
|
||||
# x86_64/CUDA is the default
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
|
||||
--mount=type=bind,source=invokeai/version,target=invokeai/version \
|
||||
ulimit -n 30000 && \
|
||||
uv sync --extra $GPU_DRIVER --frozen
|
||||
# --link requires buldkit w/ dockerfile syntax 1.4
|
||||
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
|
||||
COPY --link --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
|
||||
|
||||
# Link amdgpu.ids for ROCm builds
|
||||
# contributed by https://github.com/Rubonnek
|
||||
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
|
||||
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids" && groupadd render
|
||||
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
|
||||
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
|
||||
# build patchmatch
|
||||
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
|
||||
@@ -90,18 +135,3 @@ RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${IN
|
||||
COPY docker/docker-entrypoint.sh ./
|
||||
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
|
||||
CMD ["invokeai-web"]
|
||||
|
||||
# --link requires buldkit w/ dockerfile syntax 1.4, does not work with podman
|
||||
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
|
||||
|
||||
# add sources last to minimize image changes on code changes
|
||||
COPY invokeai ${INVOKEAI_SRC}/invokeai
|
||||
|
||||
# this should not increase image size because we've already installed dependencies
|
||||
# in a previous layer
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
ulimit -n 30000 && \
|
||||
uv pip install -e .[$GPU_DRIVER]
|
||||
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
# syntax=docker/dockerfile:1.4
|
||||
|
||||
#### Web UI ------------------------------------
|
||||
|
||||
FROM docker.io/node:22-slim AS web-builder
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
ENV PATH="$PNPM_HOME:$PATH"
|
||||
RUN corepack use pnpm@8.x
|
||||
RUN corepack enable
|
||||
|
||||
WORKDIR /build
|
||||
COPY invokeai/frontend/web/ ./
|
||||
RUN --mount=type=cache,target=/pnpm/store \
|
||||
pnpm install --frozen-lockfile
|
||||
RUN npx vite build
|
||||
|
||||
## Backend ---------------------------------------
|
||||
|
||||
FROM library/ubuntu:24.04
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
--mount=type=cache,target=/var/lib/apt \
|
||||
apt update && apt install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
git \
|
||||
gosu \
|
||||
libglib2.0-0 \
|
||||
libgl1 \
|
||||
libglx-mesa0 \
|
||||
build-essential \
|
||||
libopencv-dev \
|
||||
libstdc++-10-dev \
|
||||
wget
|
||||
|
||||
ENV \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
VIRTUAL_ENV=/opt/venv \
|
||||
INVOKEAI_SRC=/opt/invokeai \
|
||||
PYTHON_VERSION=3.12 \
|
||||
UV_PYTHON=3.12 \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_MANAGED_PYTHON=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||
INVOKEAI_ROOT=/invokeai \
|
||||
INVOKEAI_HOST=0.0.0.0 \
|
||||
INVOKEAI_PORT=9090 \
|
||||
PATH="/opt/venv/bin:$PATH" \
|
||||
CONTAINER_UID=${CONTAINER_UID:-1000} \
|
||||
CONTAINER_GID=${CONTAINER_GID:-1000}
|
||||
|
||||
ARG GPU_DRIVER=cuda
|
||||
|
||||
# Install `uv` for package management
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.6.9 /uv /uvx /bin/
|
||||
|
||||
# Install python & allow non-root user to use it by traversing the /root dir without read permissions
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv python install ${PYTHON_VERSION} && \
|
||||
# chmod --recursive a+rX /root/.local/share/uv/python
|
||||
chmod 711 /root
|
||||
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
|
||||
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
|
||||
# bind-mount instead of copy to defer adding sources to the image until next layer.
|
||||
#
|
||||
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
|
||||
# x86_64/CUDA is the default
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
|
||||
--mount=type=bind,source=invokeai/version,target=invokeai/version \
|
||||
ulimit -n 30000 && \
|
||||
uv sync --extra $GPU_DRIVER --frozen
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
--mount=type=cache,target=/var/lib/apt \
|
||||
if [ "$GPU_DRIVER" = "rocm" ]; then \
|
||||
wget -O /tmp/amdgpu-install.deb \
|
||||
https://repo.radeon.com/amdgpu-install/6.3.4/ubuntu/noble/amdgpu-install_6.3.60304-1_all.deb && \
|
||||
apt install -y /tmp/amdgpu-install.deb && \
|
||||
apt update && \
|
||||
amdgpu-install --usecase=rocm -y && \
|
||||
apt-get autoclean && \
|
||||
apt clean && \
|
||||
rm -rf /tmp/* /var/tmp/* && \
|
||||
usermod -a -G render ubuntu && \
|
||||
usermod -a -G video ubuntu && \
|
||||
echo "\\n/opt/rocm/lib\\n/opt/rocm/lib64" >> /etc/ld.so.conf.d/rocm.conf && \
|
||||
ldconfig && \
|
||||
update-alternatives --auto rocm; \
|
||||
fi
|
||||
|
||||
## Heathen711: Leaving this for review input, will remove before merge
|
||||
# RUN --mount=type=cache,target=/var/cache/apt \
|
||||
# --mount=type=cache,target=/var/lib/apt \
|
||||
# if [ "$GPU_DRIVER" = "rocm" ]; then \
|
||||
# groupadd render && \
|
||||
# usermod -a -G render ubuntu && \
|
||||
# usermod -a -G video ubuntu; \
|
||||
# fi
|
||||
|
||||
## Link amdgpu.ids for ROCm builds
|
||||
## contributed by https://github.com/Rubonnek
|
||||
# RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
|
||||
# ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
|
||||
|
||||
# build patchmatch
|
||||
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
|
||||
RUN python -c "from patchmatch import patch_match"
|
||||
|
||||
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
|
||||
|
||||
COPY docker/docker-entrypoint.sh ./
|
||||
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
|
||||
CMD ["invokeai-web"]
|
||||
|
||||
# --link requires buldkit w/ dockerfile syntax 1.4, does not work with podman
|
||||
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
|
||||
|
||||
# add sources last to minimize image changes on code changes
|
||||
COPY invokeai ${INVOKEAI_SRC}/invokeai
|
||||
|
||||
# this should not increase image size because we've already installed dependencies
|
||||
# in a previous layer
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
ulimit -n 30000 && \
|
||||
uv pip install -e .[$GPU_DRIVER]
|
||||
|
||||
@@ -47,9 +47,8 @@ services:
|
||||
|
||||
invokeai-rocm:
|
||||
<<: *invokeai
|
||||
environment:
|
||||
- AMD_VISIBLE_DEVICES=all
|
||||
- RENDER_GROUP_ID=${RENDER_GROUP_ID}
|
||||
runtime: amd
|
||||
devices:
|
||||
- /dev/kfd:/dev/kfd
|
||||
- /dev/dri:/dev/dri
|
||||
profiles:
|
||||
- rocm
|
||||
|
||||
@@ -21,17 +21,6 @@ _=$(id ${USER} 2>&1) || useradd -u ${USER_ID} ${USER}
|
||||
# ensure the UID is correct
|
||||
usermod -u ${USER_ID} ${USER} 1>/dev/null
|
||||
|
||||
## ROCM specific configuration
|
||||
# render group within the container must match the host render group
|
||||
# otherwise the container will not be able to access the host GPU.
|
||||
if [[ -v "RENDER_GROUP_ID" ]] && [[ ! -z "${RENDER_GROUP_ID}" ]]; then
|
||||
# ensure the render group exists
|
||||
groupmod -g ${RENDER_GROUP_ID} render
|
||||
usermod -a -G render ${USER}
|
||||
usermod -a -G video ${USER}
|
||||
fi
|
||||
|
||||
|
||||
### Set the $PUBLIC_KEY env var to enable SSH access.
|
||||
# We do not install openssh-server in the image by default to avoid bloat.
|
||||
# but it is useful to have the full SSH server e.g. on Runpod.
|
||||
|
||||
@@ -13,7 +13,7 @@ run() {
|
||||
|
||||
# parse .env file for build args
|
||||
build_args=$(awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env) &&
|
||||
profile="$(awk -F '=' '/GPU_DRIVER=/ {print $2}' .env)"
|
||||
profile="$(awk -F '=' '/GPU_DRIVER/ {print $2}' .env)"
|
||||
|
||||
# default to 'cuda' profile
|
||||
[[ -z "$profile" ]] && profile="cuda"
|
||||
@@ -30,7 +30,7 @@ run() {
|
||||
|
||||
printf "%s\n" "starting service $service_name"
|
||||
docker compose --profile "$profile" up -d "$service_name"
|
||||
docker compose --profile "$profile" logs -f
|
||||
docker compose logs -f
|
||||
}
|
||||
|
||||
run
|
||||
|
||||
139
docs/RELEASE.md
139
docs/RELEASE.md
@@ -1,50 +1,41 @@
|
||||
# Release Process
|
||||
|
||||
The Invoke application is published as a python package on [PyPI]. This includes both a source distribution and built distribution (a wheel).
|
||||
The app is published in twice, in different build formats.
|
||||
|
||||
Most users install it with the [Launcher](https://github.com/invoke-ai/launcher/), others with `pip`.
|
||||
|
||||
The launcher uses GitHub as the source of truth for available releases.
|
||||
|
||||
## Broad Strokes
|
||||
|
||||
- Merge all changes and bump the version in the codebase.
|
||||
- Tag the release commit.
|
||||
- Wait for the release workflow to complete.
|
||||
- Approve the PyPI publish jobs.
|
||||
- Write GH release notes.
|
||||
- A [PyPI] distribution. This includes both a source distribution and built distribution (a wheel). Users install with `pip install invokeai`. The updater uses this build.
|
||||
- An installer on the [InvokeAI Releases Page]. This is a zip file with install scripts and a wheel. This is only used for new installs.
|
||||
|
||||
## General Prep
|
||||
|
||||
Make a developer call-out for PRs to merge. Merge and test things out. Bump the version by editing `invokeai/version/invokeai_version.py`.
|
||||
Make a developer call-out for PRs to merge. Merge and test things out.
|
||||
|
||||
While the release workflow does not include end-to-end tests, it does pause before publishing so you can download and test the final build.
|
||||
|
||||
## Release Workflow
|
||||
|
||||
The `release.yml` workflow runs a number of jobs to handle code checks, tests, build and publish on PyPI.
|
||||
|
||||
It is triggered on **tag push**, when the tag matches `v*`.
|
||||
It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if you've prepped a release branch like `release/v3.5.0` or are releasing from `main` - it works the same.
|
||||
|
||||
> Because commits are reference-counted, it is safe to create a release branch, tag it, let the workflow run, then delete the branch. So long as the tag exists, that commit will exist.
|
||||
|
||||
### Triggering the Workflow
|
||||
|
||||
Ensure all commits that should be in the release are merged, and you have pulled them locally.
|
||||
Run `make tag-release` to tag the current commit and kick off the workflow.
|
||||
|
||||
Double-check that you have checked out the commit that will represent the release (typically the latest commit on `main`).
|
||||
|
||||
Run `make tag-release` to tag the current commit and kick off the workflow. You will be prompted to provide a message - use the version specifier.
|
||||
|
||||
If this version's tag already exists for some reason (maybe you had to make a last minute change), the script will overwrite it.
|
||||
|
||||
> In case you cannot use the Make target, the release may also be dispatched [manually] via GH.
|
||||
The release may also be dispatched [manually].
|
||||
|
||||
### Workflow Jobs and Process
|
||||
|
||||
The workflow consists of a number of concurrently-run checks and tests, then two final publish jobs.
|
||||
The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
|
||||
|
||||
The publish jobs require manual approval and are only run if the other jobs succeed.
|
||||
|
||||
#### `check-version` Job
|
||||
|
||||
This job ensures that the `invokeai` python package version specifier matches the tag for the release. The version specifier is pulled from the `__version__` variable in `invokeai/version/invokeai_version.py`.
|
||||
This job checks that the git ref matches the app version. It matches the ref against the `__version__` variable in `invokeai/version/invokeai_version.py`.
|
||||
|
||||
When the workflow is triggered by tag push, the ref is the tag. If the workflow is run manually, the ref is the target selected from the **Use workflow from** dropdown.
|
||||
|
||||
This job uses [samuelcolvin/check-python-version].
|
||||
|
||||
@@ -52,47 +43,62 @@ This job uses [samuelcolvin/check-python-version].
|
||||
|
||||
#### Check and Test Jobs
|
||||
|
||||
Next, these jobs run and must pass. They are the same jobs that are run for every PR.
|
||||
|
||||
- **`python-tests`**: runs `pytest` on matrix of platforms
|
||||
- **`python-checks`**: runs `ruff` (format and lint)
|
||||
- **`frontend-tests`**: runs `vitest`
|
||||
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
|
||||
- **`typegen-checks`**: ensures the frontend and backend types are synced
|
||||
|
||||
#### `build-wheel` Job
|
||||
> **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
|
||||
|
||||
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `./scripts/build_wheel.sh` and uploads `dist.zip`, which contains the wheel and unarchived build.
|
||||
> **TODO** We should add an end-to-end test job that generates an image.
|
||||
|
||||
You don't need to download or test these artifacts.
|
||||
#### `build-installer` Job
|
||||
|
||||
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
|
||||
|
||||
- **`dist`**: the python distribution, to be published on PyPI
|
||||
- **`InvokeAI-installer-${VERSION}.zip`**: the installer to be included in the GitHub release
|
||||
|
||||
#### Sanity Check & Smoke Test
|
||||
|
||||
At this point, the release workflow pauses as the remaining publish jobs require approval.
|
||||
At this point, the release workflow pauses as the remaining publish jobs require approval. Time to test the installer.
|
||||
|
||||
It's possible to test the python package before it gets published to PyPI. We've never had problems with it, so it's not necessary to do this.
|
||||
Because the installer pulls from PyPI, and we haven't published to PyPI yet, you will need to install from the wheel:
|
||||
|
||||
But, if you want to be extra-super careful, here's how to test it:
|
||||
- Download and unzip `dist.zip` and the installer from the **Summary** tab of the workflow
|
||||
- Run the installer script using the `--wheel` CLI arg, pointing at the wheel:
|
||||
|
||||
- Download the `dist.zip` build artifact from the `build-wheel` job
|
||||
- Unzip it and find the wheel file
|
||||
- Create a fresh Invoke install by following the [manual install guide](https://invoke-ai.github.io/InvokeAI/installation/manual/) - but instead of installing from PyPI, install from the wheel
|
||||
- Test the app
|
||||
```sh
|
||||
./install.sh --wheel ../InvokeAI-4.0.0rc6-py3-none-any.whl
|
||||
```
|
||||
|
||||
- Install to a temporary directory so you get the new user experience
|
||||
- Download a model and generate
|
||||
|
||||
> The same wheel file is bundled in the installer and in the `dist` artifact, which is uploaded to PyPI. You should end up with the exactly the same installation as if the installer got the wheel from PyPI.
|
||||
|
||||
##### Something isn't right
|
||||
|
||||
If testing reveals any issues, no worries. Cancel the workflow, which will cancel the pending publish jobs (you didn't approve them prematurely, right?) and start over.
|
||||
If testing reveals any issues, no worries. Cancel the workflow, which will cancel the pending publish jobs (you didn't approve them prematurely, right?).
|
||||
|
||||
Now you can start from the top:
|
||||
|
||||
- Fix the issues and PR the fixes per usual
|
||||
- Get the PR approved and merged per usual
|
||||
- Switch to `main` and pull in the fixes
|
||||
- Run `make tag-release` to move the tag to `HEAD` (which has the fixes) and kick off the release workflow again
|
||||
- Re-do the sanity check
|
||||
|
||||
#### PyPI Publish Jobs
|
||||
|
||||
The publish jobs will not run if any of the previous jobs fail.
|
||||
The publish jobs will run if any of the previous jobs fail.
|
||||
|
||||
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
|
||||
|
||||
Both jobs require a @hipsterusername or @psychedelicious to approve them from the workflow's **Summary** tab.
|
||||
Both jobs require a maintainer to approve them from the workflow's **Summary** tab.
|
||||
|
||||
- Click the **Review deployments** button
|
||||
- Select the environment (either `testpypi` or `pypi` - typically you select both)
|
||||
- Select the environment (either `testpypi` or `pypi`)
|
||||
- Click **Approve and deploy**
|
||||
|
||||
> **If the version already exists on PyPI, the publish jobs will fail.** PyPI only allows a given version to be published once - you cannot change it. If version published on PyPI has a problem, you'll need to "fail forward" by bumping the app version and publishing a followup release.
|
||||
@@ -107,33 +113,46 @@ If there are no incidents, contact @hipsterusername or @lstein, who have owner a
|
||||
|
||||
Publishes the distribution on the [Test PyPI] index, using the `testpypi` GitHub environment.
|
||||
|
||||
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release for some reason:
|
||||
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release.
|
||||
|
||||
- Approve this publish job without approving the prod publish
|
||||
- Let it finish
|
||||
- Create a fresh Invoke install by following the [manual install guide](https://invoke-ai.github.io/InvokeAI/installation/manual/), making sure to use the Test PyPI index URL: `https://test.pypi.org/simple/`
|
||||
- Test the app
|
||||
If approved and successful, you could try out the test release like this:
|
||||
|
||||
```sh
|
||||
# Create a new virtual environment
|
||||
python -m venv ~/.test-invokeai-dist --prompt test-invokeai-dist
|
||||
# Install the distribution from Test PyPI
|
||||
pip install --index-url https://test.pypi.org/simple/ invokeai
|
||||
# Run and test the app
|
||||
invokeai-web
|
||||
# Cleanup
|
||||
deactivate
|
||||
rm -rf ~/.test-invokeai-dist
|
||||
```
|
||||
|
||||
#### `publish-pypi` Job
|
||||
|
||||
Publishes the distribution on the production PyPI index, using the `pypi` GitHub environment.
|
||||
|
||||
It's a good idea to wait to approve and run this job until you have the release notes ready!
|
||||
## Publish the GitHub Release with installer
|
||||
|
||||
## Prep and publish the GitHub Release
|
||||
Once the release is published to PyPI, it's time to publish the GitHub release.
|
||||
|
||||
1. [Draft a new release] on GitHub, choosing the tag that triggered the release.
|
||||
2. The **Generate release notes** button automatically inserts the changelog and new contributors. Make sure to select the correct tags for this release and the last stable release. GH often selects the wrong tags - do this manually.
|
||||
3. Write the release notes, describing important changes. Contributions from community members should be shouted out. Use the GH-generated changelog to see all contributors. If there are Weblate translation updates, open that PR and shout out every person who contributed a translation.
|
||||
4. Check **Set as a pre-release** if it's a pre-release.
|
||||
5. Approve and wait for the `publish-pypi` job to finish if you haven't already.
|
||||
6. Publish the GH release.
|
||||
7. Post the release in Discord in the [releases](https://discord.com/channels/1020123559063990373/1149260708098359327) channel with abbreviated notes. For example:
|
||||
> Invoke v5.7.0 (stable): <https://github.com/invoke-ai/InvokeAI/releases/tag/v5.7.0>
|
||||
>
|
||||
> It's a pretty big one - Form Builder, Metadata Nodes (thanks @SkunkWorxDark!), and much more.
|
||||
8. Right click the message in releases and copy the link to it. Then, post that link in the [new-release-discussion](https://discord.com/channels/1020123559063990373/1149506274971631688) channel. For example:
|
||||
> Invoke v5.7.0 (stable): <https://discord.com/channels/1020123559063990373/1149260708098359327/1344521744916021248>
|
||||
1. Write the release notes, describing important changes. The **Generate release notes** button automatically inserts the changelog and new contributors, and you can copy/paste the intro from previous releases.
|
||||
1. Use `scripts/get_external_contributions.py` to get a list of external contributions to shout out in the release notes.
|
||||
1. Upload the zip file created in **`build`** job into the Assets section of the release notes.
|
||||
1. Check **Set as a pre-release** if it's a pre-release.
|
||||
1. Check **Create a discussion for this release**.
|
||||
1. Publish the release.
|
||||
1. Announce the release in Discord.
|
||||
|
||||
> **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
|
||||
|
||||
## Manual Build
|
||||
|
||||
The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
|
||||
|
||||
No checks are run, it just builds.
|
||||
|
||||
## Manual Release
|
||||
|
||||
@@ -141,10 +160,12 @@ The `release` workflow can be dispatched manually. You must dispatch the workflo
|
||||
|
||||
This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
|
||||
|
||||
[InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
|
||||
[PyPI]: https://pypi.org/
|
||||
[Draft a new release]: https://github.com/invoke-ai/InvokeAI/releases/new
|
||||
[Test PyPI]: https://test.pypi.org/
|
||||
[version specifier]: https://packaging.python.org/en/latest/specifications/version-specifiers/
|
||||
[ncipollo/release-action]: https://github.com/ncipollo/release-action
|
||||
[GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
|
||||
[trusted publishers]: https://docs.pypi.org/trusted-publishers/
|
||||
[samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version
|
||||
|
||||
@@ -39,7 +39,7 @@ nodes imported in the `__init__.py` file are loaded. See the README in the nodes
|
||||
folder for more examples:
|
||||
|
||||
```py
|
||||
from .cool_node import ResizeInvocation
|
||||
from .cool_node import CoolInvocation
|
||||
```
|
||||
|
||||
## Creating A New Invocation
|
||||
@@ -69,10 +69,7 @@ The first set of things we need to do when creating a new Invocation are -
|
||||
So let us do that.
|
||||
|
||||
```python
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
invocation,
|
||||
)
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
|
||||
@invocation('resize')
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
@@ -106,12 +103,8 @@ create your own custom field types later in this guide. For now, let's go ahead
|
||||
and use it.
|
||||
|
||||
```python
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
ImageField,
|
||||
InputField,
|
||||
invocation,
|
||||
)
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
|
||||
@invocation('resize')
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
@@ -135,12 +128,8 @@ image: ImageField = InputField(description="The input image")
|
||||
Great. Now let us create our other inputs for `width` and `height`
|
||||
|
||||
```python
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
ImageField,
|
||||
InputField,
|
||||
invocation,
|
||||
)
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
|
||||
@invocation('resize')
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
@@ -174,13 +163,8 @@ that are provided by it by InvokeAI.
|
||||
Let us create this function first.
|
||||
|
||||
```python
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
ImageField,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
|
||||
@invocation('resize')
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
@@ -207,14 +191,8 @@ all the necessary info related to image outputs. So let us use that.
|
||||
We will cover how to create your own output types later in this guide.
|
||||
|
||||
```python
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
ImageField,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.image import ImageOutput
|
||||
|
||||
@invocation('resize')
|
||||
@@ -239,15 +217,9 @@ Perfect. Now that we have our Invocation setup, let us do what we want to do.
|
||||
So let's do that.
|
||||
|
||||
```python
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
ImageField,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
|
||||
from invokeai.app.invocations.image import ImageOutput
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.image import ImageOutput, ResourceOrigin, ImageCategory
|
||||
|
||||
@invocation("resize")
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
|
||||
@@ -265,7 +265,7 @@ If the key is unrecognized, this call raises an
|
||||
|
||||
#### exists(key) -> AnyModelConfig
|
||||
|
||||
Returns True if a model with the given key exists in the database.
|
||||
Returns True if a model with the given key exists in the databsae.
|
||||
|
||||
#### search_by_path(path) -> AnyModelConfig
|
||||
|
||||
@@ -718,7 +718,7 @@ When downloading remote models is implemented, additional
|
||||
configuration information, such as list of trigger terms, will be
|
||||
retrieved from the HuggingFace and Civitai model repositories.
|
||||
|
||||
The probed values can be overridden by providing a dictionary in the
|
||||
The probed values can be overriden by providing a dictionary in the
|
||||
optional `config` argument passed to `import_model()`. You may provide
|
||||
overriding values for any of the model's configuration
|
||||
attributes. Here is an example of setting the
|
||||
@@ -841,7 +841,7 @@ variable.
|
||||
|
||||
#### installer.start(invoker)
|
||||
|
||||
The `start` method is called by the API initialization routines when
|
||||
The `start` method is called by the API intialization routines when
|
||||
the API starts up. Its effect is to call `sync_to_config()` to
|
||||
synchronize the model record store database with what's currently on
|
||||
disk.
|
||||
|
||||
@@ -16,7 +16,7 @@ We thank [all contributors](https://github.com/invoke-ai/InvokeAI/graphs/contrib
|
||||
- @psychedelicious (Spencer Mabrito) - Web Team Leader
|
||||
- @joshistoast (Josh Corbett) - Web Development
|
||||
- @cheerio (Mary Rogers) - Lead Engineer & Web App Development
|
||||
- @ebr (Eugene Brodsky) - Cloud/DevOps/Software engineer; your friendly neighbourhood cluster-autoscaler
|
||||
- @ebr (Eugene Brodsky) - Cloud/DevOps/Sofware engineer; your friendly neighbourhood cluster-autoscaler
|
||||
- @sunija - Standalone version
|
||||
- @brandon (Brandon Rising) - Platform, Infrastructure, Backend Systems
|
||||
- @ryanjdick (Ryan Dick) - Machine Learning & Training
|
||||
|
||||
@@ -18,19 +18,9 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
|
||||
|
||||
2. [Fork and clone][forking link] the [InvokeAI repo][repo link].
|
||||
|
||||
3. This repository uses Git LFS to manage large files. To ensure all assets are downloaded:
|
||||
- Install git-lfs → [Download here](https://git-lfs.com/)
|
||||
- Enable automatic LFS fetching for this repository:
|
||||
```shell
|
||||
git config lfs.fetchinclude "*"
|
||||
```
|
||||
- Fetch files from LFS (only needs to be done once; subsequent `git pull` will fetch changes automatically):
|
||||
```
|
||||
git lfs pull
|
||||
```
|
||||
4. Create an directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
|
||||
3. Create an directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
|
||||
|
||||
5. Follow the [manual install][manual install link] guide, with some modifications to the install command:
|
||||
4. Follow the [manual install][manual install link] guide, with some modifications to the install command:
|
||||
|
||||
- Use `.` instead of `invokeai` to install from the current directory. You don't need to specify the version.
|
||||
|
||||
@@ -41,22 +31,22 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
|
||||
With the modifications made, the install command should look something like this:
|
||||
|
||||
```sh
|
||||
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu128 --reinstall
|
||||
uv pip install -e ".[dev,test,docs,xformers]" --python 3.11 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
|
||||
```
|
||||
|
||||
6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
|
||||
5. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
|
||||
|
||||
This is because the UI build is not distributed with the source code. You need to build it manually. End the running server instance.
|
||||
|
||||
If you only want to edit the docs, you can stop here and skip to the **Documentation** section below.
|
||||
|
||||
7. Install the frontend dev toolchain, paying attention to versions:
|
||||
6. Install the frontend dev toolchain:
|
||||
|
||||
- [`nodejs`](https://nodejs.org/) (tested on LTS, v22)
|
||||
- [`nodejs`](https://nodejs.org/) (v20+)
|
||||
|
||||
- [`pnpm`](https://pnpm.io/installation) (tested on v10)
|
||||
- [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
|
||||
|
||||
8. Do a production build of the frontend:
|
||||
7. Do a production build of the frontend:
|
||||
|
||||
```sh
|
||||
cd <PATH_TO_INVOKEAI_REPO>/invokeai/frontend/web
|
||||
@@ -64,7 +54,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
|
||||
pnpm build
|
||||
```
|
||||
|
||||
9. Restart the server and navigate to the URL. You should get a UI. After making changes to the python code, restart the server to see those changes.
|
||||
8. Restart the server and navigate to the URL. You should get a UI. After making changes to the python code, restart the server to see those changes.
|
||||
|
||||
## Updating the UI
|
||||
|
||||
|
||||
183
docs/faq.md
183
docs/faq.md
@@ -1,18 +1,26 @@
|
||||
# FAQ
|
||||
|
||||
!!! info "How to Reinstall"
|
||||
|
||||
Many issues can be resolved by re-installing the application. You won't lose any data by re-installing. We suggest downloading the [latest release](https://github.com/invoke-ai/InvokeAI/releases/latest) and using it to re-install the application. Consult the [installer guide](./installation/installer.md) for more information.
|
||||
|
||||
When you run the installer, you'll have an option to select the version to install. If you aren't ready to upgrade, you choose the current version to fix a broken install.
|
||||
|
||||
If the troubleshooting steps on this page don't get you up and running, please either [create an issue] or hop on [discord] for help.
|
||||
|
||||
## How to Install
|
||||
|
||||
Follow the [Quick Start guide](./installation/quick_start.md) to install Invoke.
|
||||
You can download the latest installers [here](https://github.com/invoke-ai/InvokeAI/releases).
|
||||
|
||||
Note that any releases marked as _pre-release_ are in a beta state. You may experience some issues, but we appreciate your help testing those! For stable/reliable installations, please install the [latest release].
|
||||
|
||||
## Downloading models and using existing models
|
||||
|
||||
The Model Manager tab in the UI provides a few ways to install models, including using your already-downloaded models. You'll see a popup directing you there on first startup. For more information, see the [model install docs].
|
||||
|
||||
## Missing models after updating from v3
|
||||
## Missing models after updating to v4
|
||||
|
||||
If you find some models are missing after updating from v3, it's likely they weren't correctly registered before the update and didn't get picked up in the migration.
|
||||
If you find some models are missing after updating to v4, it's likely they weren't correctly registered before the update and didn't get picked up in the migration.
|
||||
|
||||
You can use the `Scan Folder` tab in the Model Manager UI to fix this. The models will either be in the old, now-unused `autoimport` folder, or your `models` folder.
|
||||
|
||||
@@ -29,27 +37,115 @@ Follow the same steps to scan and import the missing models.
|
||||
## Slow generation
|
||||
|
||||
- Check the [system requirements] to ensure that your system is capable of generating images.
|
||||
- Follow the [Low-VRAM mode guide](./features/low-vram.md) to optimize performance.
|
||||
- Check that your generations are happening on your GPU (if you have one). Invoke will log what is being used for generation upon startup. If your GPU isn't used, re-install to and ensure you select the appropriate GPU option.
|
||||
- If you are on Windows with an Nvidia GPU, you may have exceeded your GPU's VRAM capacity and are triggering Nvidia's "sysmem fallback". There's a guide to opt out of this behaviour in the [Low-VRAM mode guide](./features/low-vram.md).
|
||||
- Check the `ram` setting in `invokeai.yaml`. This setting tells Invoke how much of your system RAM can be used to cache models. Having this too high or too low can slow things down. That said, it's generally safest to not set this at all and instead let Invoke manage it.
|
||||
- Check the `vram` setting in `invokeai.yaml`. This setting tells Invoke how much of your GPU VRAM can be used to cache models. Counter-intuitively, if this setting is too high, Invoke will need to do a lot of shuffling of models as it juggles the VRAM cache and the currently-loaded model. The default value of 0.25 is generally works well for GPUs without 16GB or more VRAM. Even on a 24GB card, the default works well.
|
||||
- Check that your generations are happening on your GPU (if you have one). InvokeAI will log what is being used for generation upon startup. If your GPU isn't used, re-install to ensure the correct versions of torch get installed.
|
||||
- If you are on Windows, you may have exceeded your GPU's VRAM capacity and are using slower [shared GPU memory](#shared-gpu-memory-windows). There's a guide to opt out of this behaviour in the linked FAQ entry.
|
||||
|
||||
## Shared GPU Memory (Windows)
|
||||
|
||||
!!! tip "Nvidia GPUs with driver 536.40"
|
||||
|
||||
This only applies to current Nvidia cards with driver 536.40 or later, released in June 2023.
|
||||
|
||||
When the GPU doesn't have enough VRAM for a task, Windows is able to allocate some of its CPU RAM to the GPU. This is much slower than VRAM, but it does allow the system to generate when it otherwise might no have enough VRAM.
|
||||
|
||||
When shared GPU memory is used, generation slows down dramatically - but at least it doesn't crash.
|
||||
|
||||
If you'd like to opt out of this behavior and instead get an error when you exceed your GPU's VRAM, follow [this guide from Nvidia](https://nvidia.custhelp.com/app/answers/detail/a_id/5490).
|
||||
|
||||
Here's how to get the python path required in the linked guide:
|
||||
|
||||
- Run `invoke.bat`.
|
||||
- Select option 2 for developer console.
|
||||
- At least one python path will be printed. Copy the path that includes your invoke installation directory (typically the first).
|
||||
|
||||
## Installer cannot find python (Windows)
|
||||
|
||||
Ensure that you checked **Add python.exe to PATH** when installing Python. This can be found at the bottom of the Python Installer window. If you already have Python installed, you can re-run the python installer, choose the Modify option and check the box.
|
||||
|
||||
## Triton error on startup
|
||||
|
||||
This can be safely ignored. Invoke doesn't use Triton, but if you are on Linux and wish to dismiss the error, you can install Triton.
|
||||
This can be safely ignored. InvokeAI doesn't use Triton, but if you are on Linux and wish to dismiss the error, you can install Triton.
|
||||
|
||||
## Unable to Copy on Firefox
|
||||
## Updated to 3.4.0 and xformers can’t load C++/CUDA
|
||||
|
||||
Firefox does not allow Invoke to directly access the clipboard by default. As a result, you may be unable to use certain copy functions. You can fix this by configuring Firefox to allow access to write to the clipboard:
|
||||
An issue occurred with your PyTorch update. Follow these steps to fix :
|
||||
|
||||
- Go to `about:config` and click the Accept button
|
||||
- Search for `dom.events.asyncClipboard.clipboardItem`
|
||||
- Set it to `true` by clicking the toggle button
|
||||
- Restart Firefox
|
||||
1. Launch your invoke.bat / invoke.sh and select the option to open the developer console
|
||||
2. Run:`pip install ".[xformers]" --upgrade --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu121`
|
||||
- If you run into an error with `typing_extensions`, re-open the developer console and run: `pip install -U typing-extensions`
|
||||
|
||||
Note that v3.4.0 is an old, unsupported version. Please upgrade to the [latest release].
|
||||
|
||||
## Install failed and says `pip` is out of date
|
||||
|
||||
An out of date `pip` typically won't cause an installation to fail. The cause of the error can likely be found above the message that says `pip` is out of date.
|
||||
|
||||
If you saw that warning but the install went well, don't worry about it (but you can update `pip` afterwards if you'd like).
|
||||
|
||||
## Replicate image found online
|
||||
|
||||
Most example images with prompts that you'll find on the internet have been generated using different software, so you can't expect to get identical results. In order to reproduce an image, you need to replicate the exact settings and processing steps, including (but not limited to) the model, the positive and negative prompts, the seed, the sampler, the exact image size, any upscaling steps, etc.
|
||||
|
||||
## OSErrors on Windows while installing dependencies
|
||||
|
||||
During a zip file installation or an update, installation stops with an error like this:
|
||||
|
||||
{:width="800px"}
|
||||
|
||||
To resolve this, re-install the application as described above.
|
||||
|
||||
## HuggingFace install failed due to invalid access token
|
||||
|
||||
Some HuggingFace models require you to authenticate using an [access token].
|
||||
|
||||
Invoke doesn't manage this token for you, but it's easy to set it up:
|
||||
|
||||
- Follow the instructions in the link above to create an access token. Copy it.
|
||||
- Run the launcher script.
|
||||
- Select option 2 (developer console).
|
||||
- Paste the following command:
|
||||
|
||||
```sh
|
||||
python -c "import huggingface_hub; huggingface_hub.login()"
|
||||
```
|
||||
|
||||
- Paste your access token when prompted and press Enter. You won't see anything when you paste it.
|
||||
- Type `n` if prompted about git credentials.
|
||||
|
||||
If you get an error, try the command again - maybe the token didn't paste correctly.
|
||||
|
||||
Once your token is set, start Invoke and try downloading the model again. The installer will automatically use the access token.
|
||||
|
||||
If the install still fails, you may not have access to the model.
|
||||
|
||||
## Stable Diffusion XL generation fails after trying to load UNet
|
||||
|
||||
InvokeAI is working in other respects, but when trying to generate
|
||||
images with Stable Diffusion XL you get a "Server Error". The text log
|
||||
in the launch window contains this log line above several more lines of
|
||||
error messages:
|
||||
|
||||
`INFO --> Loading model:D:\LONG\PATH\TO\MODEL, type sdxl:main:unet`
|
||||
|
||||
This failure mode occurs when there is a network glitch during
|
||||
downloading the very large SDXL model.
|
||||
|
||||
To address this, first go to the Model Manager and delete the
|
||||
Stable-Diffusion-XL-base-1.X model. Then, click the HuggingFace tab,
|
||||
paste the Repo ID stabilityai/stable-diffusion-xl-base-1.0 and install
|
||||
the model.
|
||||
|
||||
## Package dependency conflicts during installation or update
|
||||
|
||||
If you have previously installed InvokeAI or another Stable Diffusion
|
||||
package, the installer may occasionally pick up outdated libraries and
|
||||
either the installer or `invoke` will fail with complaints about
|
||||
library conflicts.
|
||||
|
||||
To resolve this, re-install the application as described above.
|
||||
|
||||
## Invalid configuration file
|
||||
|
||||
Everything seems to install ok, you get a `ValidationError` when starting up the app.
|
||||
@@ -58,9 +154,64 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
|
||||
|
||||
Check the [configuration docs] for more detail about the settings and how to specify them.
|
||||
|
||||
## Out of Memory Errors
|
||||
## `ModuleNotFoundError: No module named 'controlnet_aux'`
|
||||
|
||||
The models are large, VRAM is expensive, and you may find yourself faced with Out of Memory errors when generating images. Follow our [Low-VRAM mode guide](./features/low-vram.md) to configure Invoke to prevent these.
|
||||
`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
|
||||
|
||||
If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
|
||||
|
||||
- Run the Invoke launcher
|
||||
- Choose the developer console option
|
||||
- Run this command: `pip cache remove controlnet_aux`
|
||||
- Close the terminal window
|
||||
- Download and run the [installer][latest release], selecting your current install location
|
||||
|
||||
## Out of Memory Issues
|
||||
|
||||
The models are large, VRAM is expensive, and you may find yourself
|
||||
faced with Out of Memory errors when generating images. Here are some
|
||||
tips to reduce the problem:
|
||||
|
||||
!!! info "Optimizing for GPU VRAM"
|
||||
|
||||
=== "4GB VRAM GPU"
|
||||
|
||||
This should be adequate for 512x512 pixel images using Stable Diffusion 1.5
|
||||
and derived models, provided that you do not use the NSFW checker. It won't be loaded unless you go into the UI settings and turn it on.
|
||||
|
||||
If you are on a CUDA-enabled GPU, we will automatically use xformers or torch-sdp to reduce VRAM requirements, though you can explicitly configure this. See the [configuration docs].
|
||||
|
||||
=== "6GB VRAM GPU"
|
||||
|
||||
This is a border case. Using the SD 1.5 series you should be able to
|
||||
generate images up to 640x640 with the NSFW checker enabled, and up to
|
||||
1024x1024 with it disabled.
|
||||
|
||||
If you run into persistent memory issues there are a series of
|
||||
environment variables that you can set before launching InvokeAI that
|
||||
alter how the PyTorch machine learning library manages memory. See
|
||||
<https://pytorch.org/docs/stable/notes/cuda.html#memory-management> for
|
||||
a list of these tweaks.
|
||||
|
||||
=== "12GB VRAM GPU"
|
||||
|
||||
This should be sufficient to generate larger images up to about 1280x1280.
|
||||
|
||||
## Checkpoint Models Load Slowly or Use Too Much RAM
|
||||
|
||||
The difference between diffusers models (a folder containing multiple
|
||||
subfolders) and checkpoint models (a file ending with .safetensors or
|
||||
.ckpt) is that InvokeAI is able to load diffusers models into memory
|
||||
incrementally, while checkpoint models must be loaded all at
|
||||
once. With very large models, or systems with limited RAM, you may
|
||||
experience slowdowns and other memory-related issues when loading
|
||||
checkpoint models.
|
||||
|
||||
To solve this, go to the Model Manager tab (the cube), select the
|
||||
checkpoint model that's giving you trouble, and press the "Convert"
|
||||
button in the upper right of your browser window. This will convert the
|
||||
checkpoint into a diffusers model, after which loading should be
|
||||
faster and less memory-intensive.
|
||||
|
||||
## Memory Leak (Linux)
|
||||
|
||||
@@ -102,6 +253,8 @@ Note the differences between memory allocated as chunks in an arena vs. memory a
|
||||
|
||||
[model install docs]: ./installation/models.md
|
||||
[system requirements]: ./installation/requirements.md
|
||||
[latest release]: https://github.com/invoke-ai/InvokeAI/releases/latest
|
||||
[create an issue]: https://github.com/invoke-ai/InvokeAI/issues
|
||||
[discord]: https://discord.gg/ZmtBAhwWhy
|
||||
[configuration docs]: ./configuration.md
|
||||
[access token]: https://huggingface.co/docs/hub/security-tokens#how-to-manage-user-access-tokens
|
||||
|
||||
@@ -28,13 +28,11 @@ It is possible to fine-tune the settings for best performance or if you still ge
|
||||
|
||||
## Details and fine-tuning
|
||||
|
||||
Low-VRAM mode involves 4 features, each of which can be configured or fine-tuned:
|
||||
Low-VRAM mode involves 3 features, each of which can be configured or fine-tuned:
|
||||
|
||||
- Partial model loading (`enable_partial_loading`)
|
||||
- PyTorch CUDA allocator config (`pytorch_cuda_alloc_conf`)
|
||||
- Dynamic RAM and VRAM cache sizes (`max_cache_ram_gb`, `max_cache_vram_gb`)
|
||||
- Working memory (`device_working_mem_gb`)
|
||||
- Keeping a RAM weight copy (`keep_ram_copy_of_weights`)
|
||||
- Partial model loading
|
||||
- Dynamic RAM and VRAM cache sizes
|
||||
- Working memory
|
||||
|
||||
Read on to learn about these features and understand how to fine-tune them for your system and use-cases.
|
||||
|
||||
@@ -52,16 +50,6 @@ As described above, you can enable partial model loading by adding this line to
|
||||
enable_partial_loading: true
|
||||
```
|
||||
|
||||
### PyTorch CUDA allocator config
|
||||
|
||||
The PyTorch CUDA allocator's behavior can be configured using the `pytorch_cuda_alloc_conf` config. Tuning the allocator configuration can help to reduce the peak reserved VRAM. The optimal configuration is dependent on many factors (e.g. device type, VRAM, CUDA driver version, etc.), but switching from PyTorch's native allocator to using CUDA's built-in allocator works well on many systems. To try this, add the following line to your `invokeai.yaml` file:
|
||||
|
||||
```yaml
|
||||
pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"
|
||||
```
|
||||
|
||||
A more complete explanation of the available configuration options is [here](https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf).
|
||||
|
||||
### Dynamic RAM and VRAM cache sizes
|
||||
|
||||
Loading models from disk is slow and can be a major bottleneck for performance. Invoke uses two model caches - RAM and VRAM - to reduce loading from disk to a minimum.
|
||||
@@ -79,33 +67,23 @@ As of v5.6.0, the caches are dynamically sized. The `ram` and `vram` settings ar
|
||||
But, if your GPU has enough VRAM to hold models fully, you might get a perf boost by manually setting the cache sizes in `invokeai.yaml`:
|
||||
|
||||
```yaml
|
||||
# The default max cache RAM size is logged on InvokeAI startup. It is determined based on your system RAM / VRAM.
|
||||
# You can override the default value by setting `max_cache_ram_gb`.
|
||||
# Increasing `max_cache_ram_gb` will increase the amount of RAM used to cache inactive models, resulting in faster model
|
||||
# reloads for the cached models.
|
||||
# As an example, if your system has 32GB of RAM and no other heavy processes, setting the `max_cache_ram_gb` to 28GB
|
||||
# might be a good value to achieve aggressive model caching.
|
||||
# Set the RAM cache size to as large as possible, leaving a few GB free for the rest of your system and Invoke.
|
||||
# For example, if your system has 32GB RAM, 28GB is a good value.
|
||||
max_cache_ram_gb: 28
|
||||
|
||||
# The default max cache VRAM size is adjusted dynamically based on the amount of available VRAM (taking into
|
||||
# consideration the VRAM used by other processes).
|
||||
# You can override the default value by setting `max_cache_vram_gb`.
|
||||
# CAUTION: Most users should not manually set this value. See warning below.
|
||||
max_cache_vram_gb: 16
|
||||
# Set the VRAM cache size to be as large as possible while leaving enough room for the working memory of the tasks you will be doing.
|
||||
# For example, on a 24GB GPU that will be running unquantized FLUX without any auxiliary models,
|
||||
# 18GB is a good value.
|
||||
max_cache_vram_gb: 18
|
||||
```
|
||||
|
||||
!!! warning "Max safe value for `max_cache_vram_gb`"
|
||||
!!! tip "Max safe value for `max_cache_vram_gb`"
|
||||
|
||||
Most users should not manually configure the `max_cache_vram_gb`. This configuration value takes precedence over the `device_working_mem_gb` and any operations that explicitly reserve additional working memory (e.g. VAE decode). As such, manually configuring it increases the likelihood of encountering out-of-memory errors.
|
||||
|
||||
For users who wish to configure `max_cache_vram_gb`, the max safe value can be determined by subtracting `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
|
||||
To determine the max safe value for `max_cache_vram_gb`, subtract `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
|
||||
|
||||
For example, if you have a 12GB GPU, the max safe value for `max_cache_vram_gb` is `12GB - 3GB = 9GB`.
|
||||
|
||||
If you had increased `device_working_mem_gb` to 4GB, then the max safe value for `max_cache_vram_gb` is `12GB - 4GB = 8GB`.
|
||||
|
||||
Most users who override `max_cache_vram_gb` are doing so because they wish to use significantly less VRAM, and should be setting `max_cache_vram_gb` to a value significantly less than the 'max safe value'.
|
||||
|
||||
### Working memory
|
||||
|
||||
Invoke cannot use _all_ of your VRAM for model caching and loading. It requires some VRAM to use as working memory for various operations.
|
||||
@@ -131,15 +109,6 @@ device_working_mem_gb: 4
|
||||
|
||||
Once decoding completes, the model manager "reclaims" the extra VRAM allocated as working memory for future model loading operations.
|
||||
|
||||
### Keeping a RAM weight copy
|
||||
|
||||
Invoke has the option of keeping a RAM copy of all model weights, even when they are loaded onto the GPU. This optimization is _on_ by default, and enables faster model switching and LoRA patching. Disabling this feature will reduce the average RAM load while running Invoke (peak RAM likely won't change), at the cost of slower model switching and LoRA patching. If you have limited RAM, you can disable this optimization:
|
||||
|
||||
```yaml
|
||||
# Set to false to reduce the average RAM usage at the cost of slower model switching and LoRA patching.
|
||||
keep_ram_copy_of_weights: false
|
||||
```
|
||||
|
||||
### Disabling Nvidia sysmem fallback (Windows only)
|
||||
|
||||
On Windows, Nvidia GPUs are able to use system RAM when their VRAM fills up via **sysmem fallback**. While it sounds like a good idea on the surface, in practice it causes massive slowdowns during generation.
|
||||
@@ -158,19 +127,3 @@ It is strongly suggested to disable this feature:
|
||||
If the sysmem fallback feature sounds familiar, that's because Invoke's partial model loading strategy is conceptually very similar - use VRAM when there's room, else fall back to RAM.
|
||||
|
||||
Unfortunately, the Nvidia implementation is not optimized for applications like Invoke and does more harm than good.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Windows page file
|
||||
|
||||
Invoke has high virtual memory (a.k.a. 'committed memory') requirements. This can cause issues on Windows if the page file size limits are hit. (See this issue for the technical details on why this happens: https://github.com/invoke-ai/InvokeAI/issues/7563).
|
||||
|
||||
If you run out of page file space, InvokeAI may crash. Often, these crashes will happen with one of the following errors:
|
||||
|
||||
- InvokeAI exits with Windows error code `3221225477`
|
||||
- InvokeAI crashes without an error, but `eventvwr.msc` reveals an error with code `0xc0000005` (the hex equivalent of `3221225477`)
|
||||
|
||||
If you are running out of page file space, try the following solutions:
|
||||
|
||||
- Make sure that you have sufficient disk space for the page file to grow. Watch your disk usage as Invoke runs. If it climbs near 100% leading up to the crash, then this is very likely the source of the issue. Clear out some disk space to resolve the issue.
|
||||
- Make sure that your page file is set to "System managed size" (this is the default) rather than a custom size. Under the "System managed size" policy, the page file will grow dynamically as needed.
|
||||
|
||||
121
docs/installation/legacy_scripts.md
Normal file
121
docs/installation/legacy_scripts.md
Normal file
@@ -0,0 +1,121 @@
|
||||
# Legacy Scripts
|
||||
|
||||
!!! warning "Legacy Scripts"
|
||||
|
||||
We recommend using the Invoke Launcher to install and update Invoke. It's a desktop application for Windows, macOS and Linux. It takes care of a lot of nitty gritty details for you.
|
||||
|
||||
Follow the [quick start guide](./quick_start.md) to get started.
|
||||
|
||||
!!! tip "Use the installer to update"
|
||||
|
||||
Using the installer for updates will not erase any of your data (images, models, boards, etc). It only updates the core libraries used to run Invoke.
|
||||
|
||||
Simply use the same path you installed to originally to update your existing installation.
|
||||
|
||||
Both release and pre-release versions can be installed using the installer. It also supports install through a wheel if needed.
|
||||
|
||||
Be sure to review the [installation requirements] and ensure your system has everything it needs to install Invoke.
|
||||
|
||||
## Getting the Latest Installer
|
||||
|
||||
Download the `InvokeAI-installer-vX.Y.Z.zip` file from the [latest release] page. It is at the bottom of the page, under **Assets**.
|
||||
|
||||
After unzipping the installer, you should have a `InvokeAI-Installer` folder with some files inside, including `install.bat` and `install.sh`.
|
||||
|
||||
## Running the Installer
|
||||
|
||||
!!! tip
|
||||
|
||||
Windows users should first double-click the `WinLongPathsEnabled.reg` file to prevent a failed installation due to long file paths.
|
||||
|
||||
Double-click the install script:
|
||||
|
||||
=== "Windows"
|
||||
|
||||
```sh
|
||||
install.bat
|
||||
```
|
||||
|
||||
=== "Linux/macOS"
|
||||
|
||||
```sh
|
||||
install.sh
|
||||
```
|
||||
|
||||
!!! info "Running the Installer from the commandline"
|
||||
|
||||
You can also run the install script from cmd/powershell (Windows) or terminal (Linux/macOS).
|
||||
|
||||
!!! warning "Untrusted Publisher (Windows)"
|
||||
|
||||
You may get a popup saying the file comes from an `Untrusted Publisher`. Click `More Info` and `Run Anyway` to get past this.
|
||||
|
||||
The installation process is simple, with a few prompts:
|
||||
|
||||
- Select the version to install. Unless you have a specific reason to install a specific version, select the default (the latest version).
|
||||
- Select location for the install. Be sure you have enough space in this folder for the base application, as described in the [installation requirements].
|
||||
- Select a GPU device.
|
||||
|
||||
!!! info "Slow Installation"
|
||||
|
||||
The installer needs to download several GB of data and install it all. It may appear to get stuck at 99.9% when installing `pytorch` or during a step labeled "Installing collected packages".
|
||||
|
||||
If it is stuck for over 10 minutes, something has probably gone wrong and you should close the window and restart.
|
||||
|
||||
## Running the Application
|
||||
|
||||
Find the install location you selected earlier. Double-click the launcher script to run the app:
|
||||
|
||||
=== "Windows"
|
||||
|
||||
```sh
|
||||
invoke.bat
|
||||
```
|
||||
|
||||
=== "Linux/macOS"
|
||||
|
||||
```sh
|
||||
invoke.sh
|
||||
```
|
||||
|
||||
Choose the first option to run the UI. After a series of startup messages, you'll see something like this:
|
||||
|
||||
```sh
|
||||
Uvicorn running on http://127.0.0.1:9090 (Press CTRL+C to quit)
|
||||
```
|
||||
|
||||
Copy the URL into your browser and you should see the UI.
|
||||
|
||||
## Improved Outpainting with PatchMatch
|
||||
|
||||
PatchMatch is an extra add-on that can improve outpainting. Windows users are in luck - it works out of the box.
|
||||
|
||||
On macOS and Linux, a few extra steps are needed to set it up. See the [PatchMatch installation guide](./patchmatch.md).
|
||||
|
||||
## First-time Setup
|
||||
|
||||
You will need to [install some models] before you can generate.
|
||||
|
||||
Check the [configuration docs] for details on configuring the application.
|
||||
|
||||
## Updating
|
||||
|
||||
Updating is exactly the same as installing - download the latest installer, choose the latest version, enter your existing installation path, and the app will update. None of your data (images, models, boards, etc) will be erased.
|
||||
|
||||
!!! info "Dependency Resolution Issues"
|
||||
|
||||
We've found that pip's dependency resolution can cause issues when upgrading packages. One very common problem was pip "downgrading" torch from CUDA to CPU, but things broke in other novel ways.
|
||||
|
||||
The installer doesn't have this kind of problem, so we use it for updating as well.
|
||||
|
||||
## Installation Issues
|
||||
|
||||
If you have installation issues, please review the [FAQ]. You can also [create an issue] or ask for help on [discord].
|
||||
|
||||
[installation requirements]: ./requirements.md
|
||||
[FAQ]: ../faq.md
|
||||
[install some models]: ./models.md
|
||||
[configuration docs]: ../configuration.md
|
||||
[latest release]: https://github.com/invoke-ai/InvokeAI/releases/latest
|
||||
[create an issue]: https://github.com/invoke-ai/InvokeAI/issues
|
||||
[discord]: https://discord.gg/ZmtBAhwWhy
|
||||
@@ -43,10 +43,10 @@ The following commands vary depending on the version of Invoke being installed a
|
||||
3. Create a virtual environment in that directory:
|
||||
|
||||
```sh
|
||||
uv venv --relocatable --prompt invoke --python 3.12 --python-preference only-managed .venv
|
||||
uv venv --relocatable --prompt invoke --python 3.11 --python-preference only-managed .venv
|
||||
```
|
||||
|
||||
This command creates a portable virtual environment at `.venv` complete with a portable python 3.12. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
|
||||
This command creates a portable virtual environment at `.venv` complete with a portable python 3.11. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
|
||||
|
||||
4. Activate the virtual environment:
|
||||
|
||||
@@ -64,51 +64,37 @@ The following commands vary depending on the version of Invoke being installed a
|
||||
|
||||
5. Choose a version to install. Review the [GitHub releases page](https://github.com/invoke-ai/InvokeAI/releases).
|
||||
|
||||
6. Determine the package specifier to use when installing. This is a performance optimization.
|
||||
6. Determine the package package specifier to use when installing. This is a performance optimization.
|
||||
|
||||
- If you have an Nvidia 20xx series GPU or older, use `invokeai[xformers]`.
|
||||
- If you have an Nvidia 30xx series GPU or newer, or do not have an Nvidia GPU, use `invokeai`.
|
||||
|
||||
7. Determine the torch backend to use for installation, if any. This is necessary to get the right version of torch installed. This is acheived by using [UV's built in torch support.](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection)
|
||||
7. Determine the `PyPI` index URL to use for installation, if any. This is necessary to get the right version of torch installed.
|
||||
|
||||
=== "Invoke v5.12 and later"
|
||||
=== "Invoke v5 or later"
|
||||
|
||||
- If you are on Windows or Linux with an Nvidia GPU, use `--torch-backend=cu128`.
|
||||
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.3`.
|
||||
- **In all other cases, do not use a torch backend.**
|
||||
|
||||
=== "Invoke v5.10.0 to v5.11.0"
|
||||
|
||||
- If you are on Windows or Linux with an Nvidia GPU, use `--torch-backend=cu126`.
|
||||
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.2.4`.
|
||||
- **In all other cases, do not use an index.**
|
||||
|
||||
=== "Invoke v5.0.0 to v5.9.1"
|
||||
|
||||
- If you are on Windows with an Nvidia GPU, use `--torch-backend=cu124`.
|
||||
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.1`.
|
||||
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
|
||||
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.1`.
|
||||
- **In all other cases, do not use an index.**
|
||||
|
||||
=== "Invoke v4"
|
||||
|
||||
- If you are on Windows with an Nvidia GPU, use `--torch-backend=cu124`.
|
||||
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm5.2`.
|
||||
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
|
||||
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm5.2`.
|
||||
- **In all other cases, do not use an index.**
|
||||
|
||||
8. Install the `invokeai` package. Substitute the package specifier and version.
|
||||
|
||||
```sh
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --force-reinstall
|
||||
uv pip install <PACKAGE_SPECIFIER>=<VERSION> --python 3.11 --python-preference only-managed --force-reinstall
|
||||
```
|
||||
|
||||
If you determined you needed to use a torch backend in the previous step, you'll need to set the backend like this:
|
||||
If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:
|
||||
|
||||
```sh
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --torch-backend=<VERSION> --force-reinstall
|
||||
uv pip install <PACKAGE_SPECIFIER>=<VERSION> --python 3.11 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
|
||||
```
|
||||
|
||||
9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:
|
||||
|
||||
@@ -33,45 +33,30 @@ Hardware requirements vary significantly depending on model and image output siz
|
||||
|
||||
More detail on system requirements can be found [here](./requirements.md).
|
||||
|
||||
## Step 2: Download and Set Up the Launcher
|
||||
## Step 2: Download
|
||||
|
||||
The Launcher manages your Invoke install. Follow these instructions to download and set up the Launcher.
|
||||
Download the most launcher for your operating system:
|
||||
|
||||
!!! info "Instructions for each OS"
|
||||
- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
|
||||
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)
|
||||
- [Download for Linux](https://download.invoke.ai/Invoke%20Community%20Edition.AppImage)
|
||||
|
||||
=== "Windows"
|
||||
## Step 3: Install or Update
|
||||
|
||||
- [Download for Windows](https://github.com/invoke-ai/launcher/releases/latest/download/Invoke.Community.Edition.Setup.latest.exe)
|
||||
- Run the `EXE` to install the Launcher and start it.
|
||||
- A desktop shortcut will be created; use this to run the Launcher in the future.
|
||||
- You can delete the `EXE` file you downloaded.
|
||||
|
||||
=== "macOS"
|
||||
|
||||
- [Download for macOS](https://github.com/invoke-ai/launcher/releases/latest/download/Invoke.Community.Edition-latest-arm64.dmg)
|
||||
- Open the `DMG` and drag the app into `Applications`.
|
||||
- Run the Launcher using its entry in `Applications`.
|
||||
- You can delete the `DMG` file you downloaded.
|
||||
|
||||
=== "Linux"
|
||||
|
||||
- [Download for Linux](https://github.com/invoke-ai/launcher/releases/latest/download/Invoke.Community.Edition-latest.AppImage)
|
||||
- You may need to edit the `AppImage` file properties and make it executable.
|
||||
- Optionally move the file to a location that does not require admin privileges and add a desktop shortcut for it.
|
||||
- Run the Launcher by double-clicking the `AppImage` or the shortcut you made.
|
||||
|
||||
## Step 3: Install Invoke
|
||||
|
||||
Run the Launcher you just set up if you haven't already. Click **Install** and follow the instructions to install (or update) Invoke.
|
||||
Run the launcher you just downloaded, click **Install** and follow the instructions to get set up.
|
||||
|
||||
If you have an existing Invoke installation, you can select it and let the launcher manage the install. You'll be able to update or launch the installation.
|
||||
|
||||
!!! tip "Updating"
|
||||
!!! warning "Problem running the launcher on macOS"
|
||||
|
||||
The Launcher will check for updates for itself _and_ Invoke.
|
||||
macOS may not allow you to run the launcher. We are working to resolve this by signing the launcher executable. Until that is done, you can either use the [legacy scripts](./legacy_scripts.md) to install, or manually flag the launcher as safe:
|
||||
|
||||
- When the Launcher detects an update is available for itself, you'll get a small popup window. Click through this and the Launcher will update itself.
|
||||
- When the Launcher detects an update for Invoke, you'll see a small green alert in the Launcher. Click that and follow the instructions to update Invoke.
|
||||
- Open the **Invoke-Installer-mac-arm64.dmg** file.
|
||||
- Drag the launcher to **Applications**.
|
||||
- Open a terminal.
|
||||
- Run `xattr -d 'com.apple.quarantine' /Applications/Invoke\ Community\ Edition.app`.
|
||||
|
||||
You should now be able to run the launcher.
|
||||
|
||||
## Step 4: Launch
|
||||
|
||||
@@ -114,24 +99,11 @@ We recommend watching our [Getting Started Playlist](https://www.youtube.com/pla
|
||||
- Using control layers and reference guides.
|
||||
- Refining images with advanced workflows.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If installation fails, retrying the install in Repair Mode may fix it. There's a checkbox to enable this on the Review step of the install flow.
|
||||
|
||||
If that doesn't fix it, [clearing the `uv` cache](https://docs.astral.sh/uv/reference/cli/#uv-cache-clean) might do the trick:
|
||||
|
||||
- Open and start the dev console (button at the bottom-left of the launcher).
|
||||
- Run `uv cache clean`.
|
||||
- Retry the installation. Enable Repair Mode for good measure.
|
||||
|
||||
If you are still unable to install, try installing to a different location and see if that works.
|
||||
|
||||
If you still have problems, ask for help on the Invoke [discord](https://discord.gg/ZmtBAhwWhy).
|
||||
|
||||
## Other Installation Methods
|
||||
|
||||
- You can install the Invoke application as a python package. See our [manual install](./manual.md) docs.
|
||||
- You can run Invoke with docker. See our [docker install](./docker.md) docs.
|
||||
- You can still use our legacy scripts to install and run Invoke. See the [legacy scripts](./legacy_scripts.md) docs.
|
||||
|
||||
## Need Help?
|
||||
|
||||
|
||||
@@ -4,9 +4,7 @@ Invoke runs on Windows 10+, macOS 14+ and Linux (Ubuntu 20.04+ is well-tested).
|
||||
|
||||
## Hardware
|
||||
|
||||
Hardware requirements vary significantly depending on model and image output size.
|
||||
|
||||
The requirements below are rough guidelines for best performance. GPUs with less VRAM typically still work, if a bit slower. Follow the [Low-VRAM mode guide](./features/low-vram.md) to optimize performance.
|
||||
Hardware requirements vary significantly depending on model and image output size. The requirements below are rough guidelines.
|
||||
|
||||
- All Apple Silicon (M1, M2, etc) Macs work, but 16GB+ memory is recommended.
|
||||
- AMD GPUs are supported on Linux only. The VRAM requirements are the same as Nvidia GPUs.
|
||||
@@ -41,7 +39,7 @@ The requirements below are rough guidelines for best performance. GPUs with less
|
||||
|
||||
You don't need to do this if you are installing with the [Invoke Launcher](./quick_start.md).
|
||||
|
||||
Invoke requires python 3.10 through 3.12. If you don't already have one of these versions installed, we suggest installing 3.12, as it will be supported for longer.
|
||||
Invoke requires python 3.10 or 3.11. If you don't already have one of these versions installed, we suggest installing 3.11, as it will be supported for longer.
|
||||
|
||||
Check that your system has an up-to-date Python installed by running `python3 --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).
|
||||
|
||||
@@ -49,19 +47,19 @@ Check that your system has an up-to-date Python installed by running `python3 --
|
||||
|
||||
=== "Windows"
|
||||
|
||||
- Install python with [an official installer].
|
||||
- Install python 3.11 with [an official installer].
|
||||
- The installer includes an option to add python to your PATH. Be sure to enable this. If you missed it, re-run the installer, choose to modify an existing installation, and tick that checkbox.
|
||||
- You may need to install [Microsoft Visual C++ Redistributable].
|
||||
|
||||
=== "macOS"
|
||||
|
||||
- Install python with [an official installer].
|
||||
- Install python 3.11 with [an official installer].
|
||||
- If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
|
||||
- If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.
|
||||
|
||||
=== "Linux"
|
||||
|
||||
- Installing python varies depending on your system. We recommend [using `uv` to manage your python installation](https://docs.astral.sh/uv/concepts/python-versions/#installing-a-python-version).
|
||||
- Installing python varies depending on your system. On Ubuntu, you can use the [deadsnakes PPA](https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa).
|
||||
- You'll need to install `libglib2.0-0` and `libgl1-mesa-glx` for OpenCV to work. For example, on a Debian system: `sudo apt update && sudo apt install -y libglib2.0-0 libgl1-mesa-glx`
|
||||
|
||||
## Drivers
|
||||
|
||||
@@ -41,7 +41,7 @@ Nodes have a "Use Cache" option in their footer. This allows for performance imp
|
||||
|
||||
There are several node grouping concepts that can be examined with a narrow focus. These (and other) groupings can be pieced together to make up functional graph setups, and are important to understanding how groups of nodes work together as part of a whole. Note that the screenshots below aren't examples of complete functioning node graphs (see Examples).
|
||||
|
||||
### Create Latent Noise
|
||||
### Noise
|
||||
|
||||
An initial noise tensor is necessary for the latent diffusion process. As a result, the Denoising node requires a noise node input.
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ If you'd prefer, you can also just download the whole node folder from the linke
|
||||
To use a community workflow, download the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
|
||||
|
||||
- Community Nodes
|
||||
+ [Anamorphic Tools](#anamorphic-tools)
|
||||
+ [Adapters-Linked](#adapters-linked-nodes)
|
||||
+ [Autostereogram](#autostereogram-nodes)
|
||||
+ [Average Images](#average-images)
|
||||
@@ -21,12 +20,9 @@ To use a community workflow, download the `.json` node graph file and load it in
|
||||
+ [Close Color Mask](#close-color-mask)
|
||||
+ [Clothing Mask](#clothing-mask)
|
||||
+ [Contrast Limited Adaptive Histogram Equalization](#contrast-limited-adaptive-histogram-equalization)
|
||||
+ [Curves](#curves)
|
||||
+ [Depth Map from Wavefront OBJ](#depth-map-from-wavefront-obj)
|
||||
+ [Enhance Detail](#enhance-detail)
|
||||
+ [Film Grain](#film-grain)
|
||||
+ [Flip Pose](#flip-pose)
|
||||
+ [Flux Ideal Size](#flux-ideal-size)
|
||||
+ [Generative Grammar-Based Prompt Nodes](#generative-grammar-based-prompt-nodes)
|
||||
+ [GPT2RandomPromptMaker](#gpt2randompromptmaker)
|
||||
+ [Grid to Gif](#grid-to-gif)
|
||||
@@ -65,13 +61,6 @@ To use a community workflow, download the `.json` node graph file and load it in
|
||||
- [Help](#help)
|
||||
|
||||
|
||||
--------------------------------
|
||||
### Anamorphic Tools
|
||||
|
||||
**Description:** A set of nodes to perform anamorphic modifications to images, like lens blur, streaks, spherical distortion, and vignetting.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/anamorphic-tools
|
||||
|
||||
--------------------------------
|
||||
### Adapters Linked Nodes
|
||||
|
||||
@@ -143,13 +132,6 @@ Node Link: https://github.com/VeyDlin/clahe-node
|
||||
View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/clahe-node/master/.readme/node.png" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Curves
|
||||
|
||||
**Description:** Adjust an image's curve based on a user-defined string.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/curves-node
|
||||
|
||||
--------------------------------
|
||||
### Depth Map from Wavefront OBJ
|
||||
|
||||
@@ -180,20 +162,6 @@ To be imported, an .obj must use triangulated meshes, so make sure to enable tha
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/film-grain-node
|
||||
|
||||
--------------------------------
|
||||
### Flip Pose
|
||||
|
||||
**Description:** This node will flip an openpose image horizontally, recoloring it to make sure that it isn't facing the wrong direction. Note that it does not work with openpose hands.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/flip-pose-node
|
||||
|
||||
--------------------------------
|
||||
### Flux Ideal Size
|
||||
|
||||
**Description:** This node returns an ideal size to use for the first stage of a Flux image generation pipeline. Generating at the right size helps limit duplication and odd subject placement.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/flux-ideal-size
|
||||
|
||||
--------------------------------
|
||||
### Generative Grammar-Based Prompt Nodes
|
||||
|
||||
|
||||
BIN
installer/WinLongPathsEnabled.reg
Normal file
BIN
installer/WinLongPathsEnabled.reg
Normal file
Binary file not shown.
@@ -32,18 +32,18 @@ if [[ ! -z ${CI} ]]; then
|
||||
echo
|
||||
echo -e "${BCYAN}CI environment detected${RESET}"
|
||||
echo
|
||||
else
|
||||
echo
|
||||
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
|
||||
echo "The current working directory is $(pwd)"
|
||||
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
|
||||
echo
|
||||
fi
|
||||
|
||||
echo -e "${BGREEN}HEAD${RESET}:"
|
||||
git_show HEAD
|
||||
echo
|
||||
|
||||
# If the classifiers are invalid, publishing to PyPI will fail but the build will succeed.
|
||||
# It's a fast check, do it early.
|
||||
echo "Checking pyproject classifiers..."
|
||||
python3 ./check_classifiers.py ../pyproject.toml
|
||||
echo
|
||||
|
||||
# ---------------------- FRONTEND ----------------------
|
||||
|
||||
pushd ../invokeai/frontend/web >/dev/null
|
||||
@@ -77,8 +77,42 @@ fi
|
||||
|
||||
rm -rf ../build
|
||||
|
||||
python3 -m build --outdir ../dist/ ../.
|
||||
python3 -m build --outdir dist/ ../.
|
||||
|
||||
# ----------------------
|
||||
|
||||
echo
|
||||
echo "Building installer zip files for InvokeAI ${VERSION}..."
|
||||
echo
|
||||
|
||||
# get rid of any old ones
|
||||
rm -f *.zip
|
||||
rm -rf InvokeAI-Installer
|
||||
|
||||
# copy content
|
||||
mkdir InvokeAI-Installer
|
||||
for f in templates *.txt *.reg; do
|
||||
cp -r ${f} InvokeAI-Installer/
|
||||
done
|
||||
mkdir InvokeAI-Installer/lib
|
||||
cp lib/*.py InvokeAI-Installer/lib
|
||||
|
||||
# Install scripts
|
||||
# Mac/Linux
|
||||
cp install.sh.in InvokeAI-Installer/install.sh
|
||||
chmod a+x InvokeAI-Installer/install.sh
|
||||
|
||||
# Windows
|
||||
cp install.bat.in InvokeAI-Installer/install.bat
|
||||
cp WinLongPathsEnabled.reg InvokeAI-Installer/
|
||||
|
||||
FILENAME=InvokeAI-installer-$VERSION.zip
|
||||
|
||||
# Zip everything up
|
||||
zip -r ${FILENAME} InvokeAI-Installer
|
||||
|
||||
echo
|
||||
echo -e "${BGREEN}Built installer: ./${FILENAME}${RESET}"
|
||||
echo -e "${BGREEN}Built PyPi distribution: ./dist${RESET}"
|
||||
|
||||
# clean up, but only if we are not in a github action
|
||||
@@ -91,7 +125,9 @@ fi
|
||||
if [[ ! -z ${CI} ]]; then
|
||||
echo
|
||||
echo "Setting GitHub action outputs..."
|
||||
echo "DIST_PATH=./dist/" >>$GITHUB_OUTPUT
|
||||
echo "INSTALLER_FILENAME=${FILENAME}" >>$GITHUB_OUTPUT
|
||||
echo "INSTALLER_PATH=installer/${FILENAME}" >>$GITHUB_OUTPUT
|
||||
echo "DIST_PATH=installer/dist/" >>$GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
exit 0
|
||||
128
installer/install.bat.in
Normal file
128
installer/install.bat.in
Normal file
@@ -0,0 +1,128 @@
|
||||
@echo off
|
||||
setlocal EnableExtensions EnableDelayedExpansion
|
||||
|
||||
@rem This script requires the user to install Python 3.10 or higher. All other
|
||||
@rem requirements are downloaded as needed.
|
||||
|
||||
@rem change to the script's directory
|
||||
PUSHD "%~dp0"
|
||||
|
||||
set "no_cache_dir=--no-cache-dir"
|
||||
if "%1" == "use-cache" (
|
||||
set "no_cache_dir="
|
||||
)
|
||||
|
||||
@rem Config
|
||||
@rem The version in the next line is replaced by an up to date release number
|
||||
@rem when create_installer.sh is run. Change the release number there.
|
||||
set INSTRUCTIONS=https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/
|
||||
set TROUBLESHOOTING=https://invoke-ai.github.io/InvokeAI/help/FAQ/
|
||||
set PYTHON_URL=https://www.python.org/downloads/windows/
|
||||
set MINIMUM_PYTHON_VERSION=3.10.0
|
||||
set PYTHON_URL=https://www.python.org/downloads/release/python-3109/
|
||||
|
||||
set err_msg=An error has occurred and the script could not continue.
|
||||
|
||||
@rem --------------------------- Intro -------------------------------
|
||||
echo This script will install InvokeAI and its dependencies.
|
||||
echo.
|
||||
echo BEFORE YOU START PLEASE MAKE SURE TO DO THE FOLLOWING
|
||||
echo 1. Install python 3.10 or 3.11. Python version 3.9 is no longer supported.
|
||||
echo 2. Double-click on the file WinLongPathsEnabled.reg in order to
|
||||
echo enable long path support on your system.
|
||||
echo 3. Install the Visual C++ core libraries.
|
||||
echo Please download and install the libraries from:
|
||||
echo https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170
|
||||
echo.
|
||||
echo See %INSTRUCTIONS% for more details.
|
||||
echo.
|
||||
echo FOR THE BEST USER EXPERIENCE WE SUGGEST MAXIMIZING THIS WINDOW NOW.
|
||||
pause
|
||||
|
||||
@rem ---------------------------- check Python version ---------------
|
||||
echo ***** Checking and Updating Python *****
|
||||
|
||||
call python --version >.tmp1 2>.tmp2
|
||||
if %errorlevel% == 1 (
|
||||
set err_msg=Please install Python 3.10-11. See %INSTRUCTIONS% for details.
|
||||
goto err_exit
|
||||
)
|
||||
|
||||
for /f "tokens=2" %%i in (.tmp1) do set python_version=%%i
|
||||
if "%python_version%" == "" (
|
||||
set err_msg=No python was detected on your system. Please install Python version %MINIMUM_PYTHON_VERSION% or higher. We recommend Python 3.10.12 from %PYTHON_URL%
|
||||
goto err_exit
|
||||
)
|
||||
|
||||
call :compareVersions %MINIMUM_PYTHON_VERSION% %python_version%
|
||||
if %errorlevel% == 1 (
|
||||
set err_msg=Your version of Python is too low. You need at least %MINIMUM_PYTHON_VERSION% but you have %python_version%. We recommend Python 3.10.12 from %PYTHON_URL%
|
||||
goto err_exit
|
||||
)
|
||||
|
||||
@rem Cleanup
|
||||
del /q .tmp1 .tmp2
|
||||
|
||||
@rem -------------- Install and Configure ---------------
|
||||
|
||||
call python .\lib\main.py
|
||||
pause
|
||||
exit /b
|
||||
|
||||
@rem ------------------------ Subroutines ---------------
|
||||
@rem routine to do comparison of semantic version numbers
|
||||
@rem found at https://stackoverflow.com/questions/15807762/compare-version-numbers-in-batch-file
|
||||
:compareVersions
|
||||
::
|
||||
:: Compares two version numbers and returns the result in the ERRORLEVEL
|
||||
::
|
||||
:: Returns 1 if version1 > version2
|
||||
:: 0 if version1 = version2
|
||||
:: -1 if version1 < version2
|
||||
::
|
||||
:: The nodes must be delimited by . or , or -
|
||||
::
|
||||
:: Nodes are normally strictly numeric, without a 0 prefix. A letter suffix
|
||||
:: is treated as a separate node
|
||||
::
|
||||
setlocal enableDelayedExpansion
|
||||
set "v1=%~1"
|
||||
set "v2=%~2"
|
||||
call :divideLetters v1
|
||||
call :divideLetters v2
|
||||
:loop
|
||||
call :parseNode "%v1%" n1 v1
|
||||
call :parseNode "%v2%" n2 v2
|
||||
if %n1% gtr %n2% exit /b 1
|
||||
if %n1% lss %n2% exit /b -1
|
||||
if not defined v1 if not defined v2 exit /b 0
|
||||
if not defined v1 exit /b -1
|
||||
if not defined v2 exit /b 1
|
||||
goto :loop
|
||||
|
||||
|
||||
:parseNode version nodeVar remainderVar
|
||||
for /f "tokens=1* delims=.,-" %%A in ("%~1") do (
|
||||
set "%~2=%%A"
|
||||
set "%~3=%%B"
|
||||
)
|
||||
exit /b
|
||||
|
||||
|
||||
:divideLetters versionVar
|
||||
for %%C in (a b c d e f g h i j k l m n o p q r s t u v w x y z) do set "%~1=!%~1:%%C=.%%C!"
|
||||
exit /b
|
||||
|
||||
:err_exit
|
||||
echo %err_msg%
|
||||
echo The installer will exit now.
|
||||
pause
|
||||
exit /b
|
||||
|
||||
pause
|
||||
|
||||
:Trim
|
||||
SetLocal EnableDelayedExpansion
|
||||
set Params=%*
|
||||
for /f "tokens=1*" %%a in ("!Params!") do EndLocal & set %1=%%b
|
||||
exit /b
|
||||
40
installer/install.sh.in
Executable file
40
installer/install.sh.in
Executable file
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
|
||||
# make sure we are not already in a venv
|
||||
# (don't need to check status)
|
||||
deactivate >/dev/null 2>&1
|
||||
scriptdir=$(dirname "$0")
|
||||
cd $scriptdir
|
||||
|
||||
function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }
|
||||
|
||||
MINIMUM_PYTHON_VERSION=3.10.0
|
||||
MAXIMUM_PYTHON_VERSION=3.11.100
|
||||
PYTHON=""
|
||||
for candidate in python3.11 python3.10 python3 python ; do
|
||||
if ppath=`which $candidate 2>/dev/null`; then
|
||||
# when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
|
||||
# we check that this found executable can actually run
|
||||
if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
|
||||
|
||||
python_version=$($ppath -V | awk '{ print $2 }')
|
||||
if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
|
||||
if [ $(version $python_version) -le $(version "$MAXIMUM_PYTHON_VERSION") ]; then
|
||||
PYTHON=$ppath
|
||||
break
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -z "$PYTHON" ]; then
|
||||
echo "A suitable Python interpreter could not be found"
|
||||
echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
|
||||
read -p "Press any key to exit"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
echo "For the best user experience we suggest enlarging or maximizing this window now."
|
||||
|
||||
exec $PYTHON ./lib/main.py ${@}
|
||||
read -p "Press any key to exit"
|
||||
438
installer/lib/installer.py
Normal file
438
installer/lib/installer.py
Normal file
@@ -0,0 +1,438 @@
|
||||
# Copyright (c) 2023 Eugene Brodsky (https://github.com/ebr)
|
||||
"""
|
||||
InvokeAI installer script
|
||||
"""
|
||||
|
||||
import locale
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import venv
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional, Tuple
|
||||
|
||||
SUPPORTED_PYTHON = ">=3.10.0,<=3.11.100"
|
||||
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
|
||||
BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"
|
||||
DOCS_URL = "https://invoke-ai.github.io/InvokeAI/"
|
||||
DISCORD_URL = "https://discord.gg/ZmtBAhwWhy"
|
||||
|
||||
OS = platform.uname().system
|
||||
ARCH = platform.uname().machine
|
||||
VERSION = "latest"
|
||||
|
||||
|
||||
def get_version_from_wheel_filename(wheel_filename: str) -> str:
|
||||
match = re.search(r"-(\d+\.\d+\.\d+)", wheel_filename)
|
||||
if match:
|
||||
version = match.group(1)
|
||||
return version
|
||||
else:
|
||||
raise ValueError(f"Could not extract version from wheel filename: {wheel_filename}")
|
||||
|
||||
|
||||
class Installer:
|
||||
"""
|
||||
Deploys an InvokeAI installation into a given path
|
||||
"""
|
||||
|
||||
reqs: list[str] = INSTALLER_REQS
|
||||
|
||||
def __init__(self) -> None:
|
||||
if os.getenv("VIRTUAL_ENV") is not None:
|
||||
print("A virtual environment is already activated. Please 'deactivate' before installation.")
|
||||
sys.exit(-1)
|
||||
self.bootstrap()
|
||||
self.available_releases = get_github_releases()
|
||||
|
||||
def mktemp_venv(self) -> TemporaryDirectory[str]:
|
||||
"""
|
||||
Creates a temporary virtual environment for the installer itself
|
||||
|
||||
:return: path to the created virtual environment directory
|
||||
:rtype: TemporaryDirectory
|
||||
"""
|
||||
|
||||
# Cleaning up temporary directories on Windows results in a race condition
|
||||
# and a stack trace.
|
||||
# `ignore_cleanup_errors` was only added in Python 3.10
|
||||
if OS == "Windows" and int(platform.python_version_tuple()[1]) >= 10:
|
||||
venv_dir = TemporaryDirectory(prefix=BOOTSTRAP_VENV_PREFIX, ignore_cleanup_errors=True)
|
||||
else:
|
||||
venv_dir = TemporaryDirectory(prefix=BOOTSTRAP_VENV_PREFIX)
|
||||
|
||||
venv.create(venv_dir.name, with_pip=True)
|
||||
self.venv_dir = venv_dir
|
||||
set_sys_path(Path(venv_dir.name))
|
||||
|
||||
return venv_dir
|
||||
|
||||
def bootstrap(self, verbose: bool = False) -> TemporaryDirectory[str] | None:
|
||||
"""
|
||||
Bootstrap the installer venv with packages required at install time
|
||||
"""
|
||||
|
||||
print("Initializing the installer. This may take a minute - please wait...")
|
||||
|
||||
venv_dir = self.mktemp_venv()
|
||||
pip = get_pip_from_venv(Path(venv_dir.name))
|
||||
|
||||
cmd = [pip, "install", "--require-virtualenv", "--use-pep517"]
|
||||
cmd.extend(self.reqs)
|
||||
|
||||
try:
|
||||
# upgrade pip to the latest version to avoid a confusing message
|
||||
res = upgrade_pip(Path(venv_dir.name))
|
||||
if verbose:
|
||||
print(res)
|
||||
|
||||
# run the install prerequisites installation
|
||||
res = subprocess.check_output(cmd).decode()
|
||||
|
||||
if verbose:
|
||||
print(res)
|
||||
|
||||
return venv_dir
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e)
|
||||
|
||||
def app_venv(self, venv_parent: Path) -> Path:
|
||||
"""
|
||||
Create a virtualenv for the InvokeAI installation
|
||||
"""
|
||||
|
||||
venv_dir = venv_parent / ".venv"
|
||||
|
||||
# Prefer to copy python executables
|
||||
# so that updates to system python don't break InvokeAI
|
||||
try:
|
||||
venv.create(venv_dir, with_pip=True)
|
||||
# If installing over an existing environment previously created with symlinks,
|
||||
# the executables will fail to copy. Keep symlinks in that case
|
||||
except shutil.SameFileError:
|
||||
venv.create(venv_dir, with_pip=True, symlinks=True)
|
||||
|
||||
return venv_dir
|
||||
|
||||
def install(
|
||||
self,
|
||||
root: str = "~/invokeai",
|
||||
yes_to_all: bool = False,
|
||||
find_links: Optional[str] = None,
|
||||
wheel: Optional[Path] = None,
|
||||
) -> None:
|
||||
"""Install the InvokeAI application into the given runtime path
|
||||
|
||||
Args:
|
||||
root: Destination path for the installation
|
||||
yes_to_all: Accept defaults to all questions
|
||||
find_links: A local directory to search for requirement wheels before going to remote indexes
|
||||
wheel: A wheel file to install
|
||||
"""
|
||||
|
||||
import messages
|
||||
|
||||
if wheel:
|
||||
messages.installing_from_wheel(wheel.name)
|
||||
version = get_version_from_wheel_filename(wheel.name)
|
||||
else:
|
||||
messages.welcome(self.available_releases)
|
||||
version = messages.choose_version(self.available_releases)
|
||||
|
||||
auto_dest = Path(os.environ.get("INVOKEAI_ROOT", root)).expanduser().resolve()
|
||||
destination = auto_dest if yes_to_all else messages.dest_path(root)
|
||||
if destination is None:
|
||||
print("Could not find or create the destination directory. Installation cancelled.")
|
||||
sys.exit(0)
|
||||
|
||||
# create the venv for the app
|
||||
self.venv = self.app_venv(venv_parent=destination)
|
||||
|
||||
self.instance = InvokeAiInstance(runtime=destination, venv=self.venv, version=version)
|
||||
|
||||
# install dependencies and the InvokeAI application
|
||||
(extra_index_url, optional_modules) = get_torch_source() if not yes_to_all else (None, None)
|
||||
self.instance.install(extra_index_url, optional_modules, find_links, wheel)
|
||||
|
||||
# install the launch/update scripts into the runtime directory
|
||||
self.instance.install_user_scripts()
|
||||
|
||||
message = f"""
|
||||
*** Installation Successful ***
|
||||
|
||||
To start the application, run:
|
||||
{destination}/invoke.{"bat" if sys.platform == "win32" else "sh"}
|
||||
|
||||
For more information, troubleshooting and support, visit our docs at:
|
||||
{DOCS_URL}
|
||||
|
||||
Join the community on Discord:
|
||||
{DISCORD_URL}
|
||||
"""
|
||||
print(message)
|
||||
|
||||
|
||||
class InvokeAiInstance:
|
||||
"""
|
||||
Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
|
||||
The virtual environment *may* reside within the runtime directory.
|
||||
A single runtime directory *may* be shared by multiple virtual environments, though this isn't currently tested or supported.
|
||||
"""
|
||||
|
||||
def __init__(self, runtime: Path, venv: Path, version: str = "stable") -> None:
|
||||
self.runtime = runtime
|
||||
self.venv = venv
|
||||
self.pip = get_pip_from_venv(venv)
|
||||
self.version = version
|
||||
|
||||
set_sys_path(venv)
|
||||
os.environ["INVOKEAI_ROOT"] = str(self.runtime.expanduser().resolve())
|
||||
os.environ["VIRTUAL_ENV"] = str(self.venv.expanduser().resolve())
|
||||
upgrade_pip(venv)
|
||||
|
||||
def get(self) -> tuple[Path, Path]:
|
||||
"""
|
||||
Get the location of the virtualenv directory for this installation
|
||||
|
||||
:return: Paths of the runtime and the venv directory
|
||||
:rtype: tuple[Path, Path]
|
||||
"""
|
||||
|
||||
return (self.runtime, self.venv)
|
||||
|
||||
def install(
|
||||
self,
|
||||
extra_index_url: Optional[str] = None,
|
||||
optional_modules: Optional[str] = None,
|
||||
find_links: Optional[str] = None,
|
||||
wheel: Optional[Path] = None,
|
||||
):
|
||||
"""Install the package from PyPi or a wheel, if provided.
|
||||
|
||||
Args:
|
||||
extra_index_url: the "--extra-index-url ..." line for pip to look in extra indexes.
|
||||
optional_modules: optional modules to install using "[module1,module2]" format.
|
||||
find_links: path to a directory containing wheels to be searched prior to going to the internet
|
||||
wheel: a wheel file to install
|
||||
"""
|
||||
|
||||
import messages
|
||||
|
||||
# not currently used, but may be useful for "install most recent version" option
|
||||
if self.version == "prerelease":
|
||||
version = None
|
||||
pre_flag = "--pre"
|
||||
elif self.version == "stable":
|
||||
version = None
|
||||
pre_flag = None
|
||||
else:
|
||||
version = self.version
|
||||
pre_flag = None
|
||||
|
||||
src = "invokeai"
|
||||
if optional_modules:
|
||||
src += optional_modules
|
||||
if version:
|
||||
src += f"=={version}"
|
||||
|
||||
messages.simple_banner("Installing the InvokeAI Application :art:")
|
||||
|
||||
from plumbum import FG, ProcessExecutionError, local
|
||||
|
||||
pip = local[self.pip]
|
||||
|
||||
# Uninstall xformers if it is present; the correct version of it will be reinstalled if needed
|
||||
_ = pip["uninstall", "-yqq", "xformers"] & FG
|
||||
|
||||
pipeline = pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"--force-reinstall",
|
||||
"--use-pep517",
|
||||
str(src) if not wheel else str(wheel),
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
"--extra-index-url" if extra_index_url is not None else None,
|
||||
extra_index_url,
|
||||
pre_flag if not wheel else None, # Ignore the flag if we are installing a wheel
|
||||
]
|
||||
|
||||
try:
|
||||
_ = pipeline & FG
|
||||
except ProcessExecutionError as e:
|
||||
print(f"Error: {e}")
|
||||
print(
|
||||
"Could not install InvokeAI. Please try downloading the latest version of the installer and install again."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
def install_user_scripts(self):
|
||||
"""
|
||||
Copy the launch and update scripts to the runtime dir
|
||||
"""
|
||||
|
||||
ext = "bat" if OS == "Windows" else "sh"
|
||||
|
||||
scripts = ["invoke"]
|
||||
|
||||
for script in scripts:
|
||||
src = Path(__file__).parent / ".." / "templates" / f"{script}.{ext}.in"
|
||||
dest = self.runtime / f"{script}.{ext}"
|
||||
shutil.copy(src, dest)
|
||||
os.chmod(dest, 0o0755)
|
||||
|
||||
|
||||
### Utility functions ###
|
||||
|
||||
|
||||
def get_pip_from_venv(venv_path: Path) -> str:
|
||||
"""
|
||||
Given a path to a virtual environment, get the absolute path to the `pip` executable
|
||||
in a cross-platform fashion. Does not validate that the pip executable
|
||||
actually exists in the virtualenv.
|
||||
|
||||
:param venv_path: Path to the virtual environment
|
||||
:type venv_path: Path
|
||||
:return: Absolute path to the pip executable
|
||||
:rtype: str
|
||||
"""
|
||||
|
||||
pip = "Scripts\\pip.exe" if OS == "Windows" else "bin/pip"
|
||||
return str(venv_path.expanduser().resolve() / pip)
|
||||
|
||||
|
||||
def upgrade_pip(venv_path: Path) -> str | None:
|
||||
"""
|
||||
Upgrade the pip executable in the given virtual environment
|
||||
"""
|
||||
|
||||
python = "Scripts\\python.exe" if OS == "Windows" else "bin/python"
|
||||
python = str(venv_path.expanduser().resolve() / python)
|
||||
|
||||
try:
|
||||
result = subprocess.check_output([python, "-m", "pip", "install", "--upgrade", "pip"]).decode(
|
||||
encoding=locale.getpreferredencoding()
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e)
|
||||
result = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def set_sys_path(venv_path: Path) -> None:
|
||||
"""
|
||||
Given a path to a virtual environment, set the sys.path, in a cross-platform fashion,
|
||||
such that packages from the given venv may be imported in the current process.
|
||||
Ensure that the packages from system environment are not visible (emulate
|
||||
the virtual env 'activate' script) - this doesn't work on Windows yet.
|
||||
|
||||
:param venv_path: Path to the virtual environment
|
||||
:type venv_path: Path
|
||||
"""
|
||||
|
||||
# filter out any paths in sys.path that may be system- or user-wide
|
||||
# but leave the temporary bootstrap virtualenv as it contains packages we
|
||||
# temporarily need at install time
|
||||
sys.path = list(filter(lambda p: not p.endswith("-packages") or p.find(BOOTSTRAP_VENV_PREFIX) != -1, sys.path))
|
||||
|
||||
# determine site-packages/lib directory location for the venv
|
||||
lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
|
||||
|
||||
# add the site-packages location to the venv
|
||||
sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))
|
||||
|
||||
|
||||
def get_github_releases() -> tuple[list[str], list[str]] | None:
|
||||
"""
|
||||
Query Github for published (pre-)release versions.
|
||||
Return a tuple where the first element is a list of stable releases and the second element is a list of pre-releases.
|
||||
Return None if the query fails for any reason.
|
||||
"""
|
||||
|
||||
import requests
|
||||
|
||||
## get latest releases using github api
|
||||
url = "https://api.github.com/repos/invoke-ai/InvokeAI/releases"
|
||||
releases: list[str] = []
|
||||
pre_releases: list[str] = []
|
||||
try:
|
||||
res = requests.get(url)
|
||||
res.raise_for_status()
|
||||
tag_info = res.json()
|
||||
for tag in tag_info:
|
||||
if not tag["prerelease"]:
|
||||
releases.append(tag["tag_name"].lstrip("v"))
|
||||
else:
|
||||
pre_releases.append(tag["tag_name"].lstrip("v"))
|
||||
except requests.HTTPError as e:
|
||||
print(f"Error: {e}")
|
||||
print("Could not fetch version information from GitHub. Please check your network connection and try again.")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print("An unexpected error occurred while trying to fetch version information from GitHub. Please try again.")
|
||||
return
|
||||
|
||||
releases.sort(reverse=True)
|
||||
pre_releases.sort(reverse=True)
|
||||
|
||||
return releases, pre_releases
|
||||
|
||||
|
||||
def get_torch_source() -> Tuple[str | None, str | None]:
|
||||
"""
|
||||
Determine the extra index URL for pip to use for torch installation.
|
||||
This depends on the OS and the graphics accelerator in use.
|
||||
This is only applicable to Windows and Linux, since PyTorch does not
|
||||
offer accelerated builds for macOS.
|
||||
|
||||
Prefer CUDA-enabled wheels if the user wasn't sure of their GPU, as it will fallback to CPU if possible.
|
||||
|
||||
A NoneType return means just go to PyPi.
|
||||
|
||||
:return: tuple consisting of (extra index url or None, optional modules to load or None)
|
||||
:rtype: list
|
||||
"""
|
||||
|
||||
from messages import GpuType, select_gpu
|
||||
|
||||
# device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
|
||||
device = select_gpu()
|
||||
|
||||
# The correct extra index URLs for torch are inconsistent, see https://pytorch.org/get-started/locally/#start-locally
|
||||
|
||||
url = None
|
||||
optional_modules: str | None = None
|
||||
if OS == "Linux":
|
||||
if device == GpuType.ROCM:
|
||||
url = "https://download.pytorch.org/whl/rocm6.1"
|
||||
elif device == GpuType.CPU:
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
elif device == GpuType.CUDA:
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
optional_modules = "[onnx-cuda]"
|
||||
elif device == GpuType.CUDA_WITH_XFORMERS:
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
optional_modules = "[xformers,onnx-cuda]"
|
||||
elif OS == "Windows":
|
||||
if device == GpuType.CUDA:
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
optional_modules = "[onnx-cuda]"
|
||||
elif device == GpuType.CUDA_WITH_XFORMERS:
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
optional_modules = "[xformers,onnx-cuda]"
|
||||
elif device.value == "cpu":
|
||||
# CPU uses the default PyPi index, no optional modules
|
||||
pass
|
||||
elif OS == "Darwin":
|
||||
# macOS uses the default PyPi index, no optional modules
|
||||
pass
|
||||
|
||||
# Fall back to defaults
|
||||
|
||||
return (url, optional_modules)
|
||||
57
installer/lib/main.py
Normal file
57
installer/lib/main.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
InvokeAI Installer
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from installer import Installer
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--root",
|
||||
dest="root",
|
||||
type=str,
|
||||
help="Destination path for installation",
|
||||
default=os.environ.get("INVOKEAI_ROOT") or "~/invokeai",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-y",
|
||||
"--yes",
|
||||
"--yes-to-all",
|
||||
dest="yes_to_all",
|
||||
action="store_true",
|
||||
help="Assume default answers to all questions",
|
||||
default=False,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--find-links",
|
||||
dest="find_links",
|
||||
help="Specifies a directory of local wheel files to be searched prior to searching the online repositories.",
|
||||
type=Path,
|
||||
default=None,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wheel",
|
||||
dest="wheel",
|
||||
help="Specifies a wheel for the InvokeAI package. Used for troubleshooting or testing prereleases.",
|
||||
type=Path,
|
||||
default=None,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
inst = Installer()
|
||||
|
||||
try:
|
||||
inst.install(**args.__dict__)
|
||||
except KeyboardInterrupt:
|
||||
print("\n")
|
||||
print("Ctrl-C pressed. Aborting.")
|
||||
print("Come back soon!")
|
||||
342
installer/lib/messages.py
Normal file
342
installer/lib/messages.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# Copyright (c) 2023 Eugene Brodsky (https://github.com/ebr)
|
||||
"""
|
||||
Installer user interaction
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit.completion import FuzzyWordCompleter, PathCompleter
|
||||
from prompt_toolkit.validation import Validator
|
||||
from rich import box, print
|
||||
from rich.console import Console, Group, group
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm
|
||||
from rich.style import Style
|
||||
from rich.syntax import Syntax
|
||||
from rich.text import Text
|
||||
|
||||
OS = platform.uname().system
|
||||
ARCH = platform.uname().machine
|
||||
|
||||
if OS == "Windows":
|
||||
# Windows terminals look better without a background colour
|
||||
console = Console(style=Style(color="grey74"))
|
||||
else:
|
||||
console = Console(style=Style(color="grey74", bgcolor="grey19"))
|
||||
|
||||
|
||||
def welcome(available_releases: tuple[list[str], list[str]] | None = None) -> None:
|
||||
@group()
|
||||
def text():
|
||||
if (platform_specific := _platform_specific_help()) is not None:
|
||||
yield platform_specific
|
||||
yield ""
|
||||
yield Text.from_markup(
|
||||
"Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.",
|
||||
justify="center",
|
||||
)
|
||||
if available_releases is not None:
|
||||
latest_stable = available_releases[0][0]
|
||||
last_pre = available_releases[1][0]
|
||||
yield ""
|
||||
yield Text.from_markup(
|
||||
f"[red3]🠶[/] Latest stable release (recommended): [b bright_white]{latest_stable}", justify="center"
|
||||
)
|
||||
yield Text.from_markup(
|
||||
f"[red3]🠶[/] Last published pre-release version: [b bright_white]{last_pre}", justify="center"
|
||||
)
|
||||
|
||||
console.rule()
|
||||
print(
|
||||
Panel(
|
||||
title="[bold wheat1]Welcome to the InvokeAI Installer",
|
||||
renderable=text(),
|
||||
box=box.DOUBLE,
|
||||
expand=True,
|
||||
padding=(1, 2),
|
||||
style=Style(bgcolor="grey23", color="orange1"),
|
||||
subtitle=f"[bold grey39]{OS}-{ARCH}",
|
||||
)
|
||||
)
|
||||
console.line()
|
||||
|
||||
|
||||
def installing_from_wheel(wheel_filename: str) -> None:
|
||||
"""Display a message about installing from a wheel"""
|
||||
|
||||
@group()
|
||||
def text():
|
||||
yield Text.from_markup(f"You are installing from a wheel file: [bold]{wheel_filename}\n")
|
||||
yield Text.from_markup(
|
||||
"[bold orange3]If you are not sure why you are doing this, you should cancel and install InvokeAI normally."
|
||||
)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
title="Installing from Wheel",
|
||||
renderable=text(),
|
||||
box=box.DOUBLE,
|
||||
expand=True,
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
should_proceed = Confirm.ask("Do you want to proceed?")
|
||||
|
||||
if not should_proceed:
|
||||
console.print("Installation cancelled.")
|
||||
exit()
|
||||
|
||||
|
||||
def choose_version(available_releases: tuple[list[str], list[str]] | None = None) -> str:
|
||||
"""
|
||||
Prompt the user to choose an Invoke version to install
|
||||
"""
|
||||
|
||||
# short circuit if we couldn't get a version list
|
||||
# still try to install the latest stable version
|
||||
if available_releases is None:
|
||||
return "stable"
|
||||
|
||||
console.print(":grey_question: [orange3]Please choose an Invoke version to install.")
|
||||
|
||||
choices = available_releases[0] + available_releases[1]
|
||||
|
||||
response = prompt(
|
||||
message=f" <Enter> to install the recommended release ({choices[0]}). <Tab> or type to pick a version: ",
|
||||
complete_while_typing=True,
|
||||
completer=FuzzyWordCompleter(choices),
|
||||
)
|
||||
console.print(f" Version {choices[0] if response == '' else response} will be installed.")
|
||||
|
||||
console.line()
|
||||
|
||||
return "stable" if response == "" else response
|
||||
|
||||
|
||||
def confirm_install(dest: Path) -> bool:
|
||||
if dest.exists():
|
||||
print(f":stop_sign: Directory {dest} already exists!")
|
||||
print(" Is this location correct?")
|
||||
default = False
|
||||
else:
|
||||
print(f":file_folder: InvokeAI will be installed in {dest}")
|
||||
default = True
|
||||
|
||||
dest_confirmed = Confirm.ask(" Please confirm:", default=default)
|
||||
|
||||
console.line()
|
||||
|
||||
return dest_confirmed
|
||||
|
||||
|
||||
def dest_path(dest: Optional[str | Path] = None) -> Path | None:
|
||||
"""
|
||||
Prompt the user for the destination path and create the path
|
||||
|
||||
:param dest: a filesystem path, defaults to None
|
||||
:type dest: str, optional
|
||||
:return: absolute path to the created installation directory
|
||||
:rtype: Path
|
||||
"""
|
||||
|
||||
if dest is not None:
|
||||
dest = Path(dest).expanduser().resolve()
|
||||
else:
|
||||
dest = Path.cwd().expanduser().resolve()
|
||||
prev_dest = init_path = dest
|
||||
dest_confirmed = False
|
||||
|
||||
while not dest_confirmed:
|
||||
browse_start = (dest or Path.cwd()).expanduser().resolve()
|
||||
|
||||
path_completer = PathCompleter(
|
||||
only_directories=True,
|
||||
expanduser=True,
|
||||
get_paths=lambda: [str(browse_start)], # noqa: B023
|
||||
# get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
|
||||
)
|
||||
|
||||
console.line()
|
||||
|
||||
console.print(f":grey_question: [orange3]Please select the install destination:[/] \\[{browse_start}]: ")
|
||||
selected = prompt(
|
||||
">>> ",
|
||||
complete_in_thread=True,
|
||||
completer=path_completer,
|
||||
default=str(browse_start) + os.sep,
|
||||
vi_mode=True,
|
||||
complete_while_typing=True,
|
||||
# Test that this is not needed on Windows
|
||||
# complete_style=CompleteStyle.READLINE_LIKE,
|
||||
)
|
||||
prev_dest = dest
|
||||
dest = Path(selected)
|
||||
|
||||
console.line()
|
||||
|
||||
dest_confirmed = confirm_install(dest.expanduser().resolve())
|
||||
|
||||
if not dest_confirmed:
|
||||
dest = prev_dest
|
||||
|
||||
dest = dest.expanduser().resolve()
|
||||
|
||||
try:
|
||||
dest.mkdir(exist_ok=True, parents=True)
|
||||
return dest
|
||||
except PermissionError:
|
||||
console.print(
|
||||
f"Failed to create directory {dest} due to insufficient permissions",
|
||||
style=Style(color="red"),
|
||||
highlight=True,
|
||||
)
|
||||
except OSError:
|
||||
console.print_exception()
|
||||
|
||||
if Confirm.ask("Would you like to try again?"):
|
||||
dest_path(init_path)
|
||||
else:
|
||||
console.rule("Goodbye!")
|
||||
|
||||
|
||||
class GpuType(Enum):
|
||||
CUDA_WITH_XFORMERS = "xformers"
|
||||
CUDA = "cuda"
|
||||
ROCM = "rocm"
|
||||
CPU = "cpu"
|
||||
|
||||
|
||||
def select_gpu() -> GpuType:
|
||||
"""
|
||||
Prompt the user to select the GPU driver
|
||||
"""
|
||||
|
||||
if ARCH == "arm64" and OS != "Darwin":
|
||||
print(f"Only CPU acceleration is available on {ARCH} architecture. Proceeding with that.")
|
||||
return GpuType.CPU
|
||||
|
||||
nvidia = (
|
||||
"an [gold1 b]NVIDIA[/] RTX 3060 or newer GPU using CUDA",
|
||||
GpuType.CUDA,
|
||||
)
|
||||
vintage_nvidia = (
|
||||
"an [gold1 b]NVIDIA[/] RTX 20xx or older GPU using CUDA+xFormers",
|
||||
GpuType.CUDA_WITH_XFORMERS,
|
||||
)
|
||||
amd = (
|
||||
"an [gold1 b]AMD[/] GPU using ROCm",
|
||||
GpuType.ROCM,
|
||||
)
|
||||
cpu = (
|
||||
"Do not install any GPU support, use CPU for generation (slow)",
|
||||
GpuType.CPU,
|
||||
)
|
||||
|
||||
options = []
|
||||
if OS == "Windows":
|
||||
options = [nvidia, vintage_nvidia, cpu]
|
||||
if OS == "Linux":
|
||||
options = [nvidia, vintage_nvidia, amd, cpu]
|
||||
elif OS == "Darwin":
|
||||
options = [cpu]
|
||||
|
||||
if len(options) == 1:
|
||||
return options[0][1]
|
||||
|
||||
options = {str(i): opt for i, opt in enumerate(options, 1)}
|
||||
|
||||
console.rule(":space_invader: GPU (Graphics Card) selection :space_invader:")
|
||||
console.print(
|
||||
Panel(
|
||||
Group(
|
||||
"\n".join(
|
||||
[
|
||||
f"Detected the [gold1]{OS}-{ARCH}[/] platform",
|
||||
"",
|
||||
"See [deep_sky_blue1]https://invoke-ai.github.io/InvokeAI/installation/requirements/[/] to ensure your system meets the minimum requirements.",
|
||||
"",
|
||||
"[red3]🠶[/] [b]Your GPU drivers must be correctly installed before using InvokeAI![/] [red3]🠴[/]",
|
||||
]
|
||||
),
|
||||
"",
|
||||
"Please select the type of GPU installed in your computer.",
|
||||
Panel(
|
||||
"\n".join([f"[dark_goldenrod b i]{i}[/] [dark_red]🢒[/]{opt[0]}" for (i, opt) in options.items()]),
|
||||
box=box.MINIMAL,
|
||||
),
|
||||
),
|
||||
box=box.MINIMAL,
|
||||
padding=(1, 1),
|
||||
)
|
||||
)
|
||||
choice = prompt(
|
||||
"Please make your selection: ",
|
||||
validator=Validator.from_callable(
|
||||
lambda n: n in options.keys(), error_message="Please select one the above options"
|
||||
),
|
||||
)
|
||||
|
||||
return options[choice][1]
|
||||
|
||||
|
||||
def simple_banner(message: str) -> None:
|
||||
"""
|
||||
A simple banner with a message, defined here for styling consistency
|
||||
|
||||
:param message: The message to display
|
||||
:type message: str
|
||||
"""
|
||||
|
||||
console.rule(message)
|
||||
|
||||
|
||||
# TODO this does not yet work correctly
|
||||
def windows_long_paths_registry() -> None:
|
||||
"""
|
||||
Display a message about applying the Windows long paths registry fix
|
||||
"""
|
||||
|
||||
with open(str(Path(__file__).parent / "WinLongPathsEnabled.reg"), "r", encoding="utf-16le") as code:
|
||||
syntax = Syntax(code.read(), line_numbers=True, lexer="regedit")
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Group(
|
||||
"\n".join(
|
||||
[
|
||||
"We will now apply a registry fix to enable long paths on Windows. InvokeAI needs this to function correctly. We are asking your permission to modify the Windows Registry on your behalf.",
|
||||
"",
|
||||
"This is the change that will be applied:",
|
||||
str(syntax),
|
||||
]
|
||||
)
|
||||
),
|
||||
title="Windows Long Paths registry fix",
|
||||
box=box.HORIZONTALS,
|
||||
padding=(1, 1),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _platform_specific_help() -> Text | None:
|
||||
if OS == "Darwin":
|
||||
text = Text.from_markup(
|
||||
"""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/]."""
|
||||
)
|
||||
elif OS == "Windows":
|
||||
text = Text.from_markup(
|
||||
"""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
|
||||
1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to
|
||||
enable long path support on your system.
|
||||
2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from
|
||||
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]"""
|
||||
)
|
||||
else:
|
||||
return
|
||||
return text
|
||||
52
installer/readme.txt
Normal file
52
installer/readme.txt
Normal file
@@ -0,0 +1,52 @@
|
||||
InvokeAI
|
||||
|
||||
Project homepage: https://github.com/invoke-ai/InvokeAI
|
||||
|
||||
Preparations:
|
||||
|
||||
You will need to install Python 3.10 or higher for this installer
|
||||
to work. Instructions are given here:
|
||||
https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/
|
||||
|
||||
Before you start the installer, please open up your system's command
|
||||
line window (Terminal or Command) and type the commands:
|
||||
|
||||
python --version
|
||||
|
||||
If all is well, it will print "Python 3.X.X", where the version number
|
||||
is at least 3.10.*, and not higher than 3.11.*.
|
||||
|
||||
If this works, check the version of the Python package manager, pip:
|
||||
|
||||
pip --version
|
||||
|
||||
You should get a message that indicates that the pip package
|
||||
installer was derived from Python 3.10 or 3.11. For example:
|
||||
"pip 22.0.1 from /usr/bin/pip (python 3.10)"
|
||||
|
||||
Long Paths on Windows:
|
||||
|
||||
If you are on Windows, you will need to enable Windows Long Paths to
|
||||
run InvokeAI successfully. If you're not sure what this is, you
|
||||
almost certainly need to do this.
|
||||
|
||||
Simply double-click the "WinLongPathsEnabled.reg" file located in
|
||||
this directory, and approve the Windows warnings. Note that you will
|
||||
need to have admin privileges in order to do this.
|
||||
|
||||
Launching the installer:
|
||||
|
||||
Windows: double-click the 'install.bat' file (while keeping it inside
|
||||
the InvokeAI-Installer folder).
|
||||
|
||||
Linux and Mac: Please open the terminal application and run
|
||||
'./install.sh' (while keeping it inside the InvokeAI-Installer
|
||||
folder).
|
||||
|
||||
The installer will create a directory of your choice and install the
|
||||
InvokeAI application within it. This directory contains everything you need to run
|
||||
invokeai. Once InvokeAI is up and running, you may delete the
|
||||
InvokeAI-Installer folder at your convenience.
|
||||
|
||||
For more information, please see
|
||||
https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/
|
||||
54
installer/templates/invoke.bat.in
Normal file
54
installer/templates/invoke.bat.in
Normal file
@@ -0,0 +1,54 @@
|
||||
@echo off
|
||||
|
||||
PUSHD "%~dp0"
|
||||
setlocal
|
||||
|
||||
call .venv\Scripts\activate.bat
|
||||
set INVOKEAI_ROOT=.
|
||||
|
||||
:start
|
||||
echo Desired action:
|
||||
echo 1. Generate images with the browser-based interface
|
||||
echo 2. Open the developer console
|
||||
echo 3. Command-line help
|
||||
echo Q - Quit
|
||||
echo.
|
||||
echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest
|
||||
echo.
|
||||
set /P choice="Please enter 1-4, Q: [1] "
|
||||
if not defined choice set choice=1
|
||||
IF /I "%choice%" == "1" (
|
||||
echo Starting the InvokeAI browser-based UI..
|
||||
python .venv\Scripts\invokeai-web.exe %*
|
||||
) ELSE IF /I "%choice%" == "2" (
|
||||
echo Developer Console
|
||||
echo Python command is:
|
||||
where python
|
||||
echo Python version is:
|
||||
python --version
|
||||
echo *************************
|
||||
echo You are now in the system shell, with the local InvokeAI Python virtual environment activated,
|
||||
echo so that you can troubleshoot this InvokeAI installation as necessary.
|
||||
echo *************************
|
||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||
call cmd /k
|
||||
) ELSE IF /I "%choice%" == "3" (
|
||||
echo Displaying command line help...
|
||||
python .venv\Scripts\invokeai-web.exe --help %*
|
||||
pause
|
||||
exit /b
|
||||
) ELSE IF /I "%choice%" == "q" (
|
||||
echo Goodbye!
|
||||
goto ending
|
||||
) ELSE (
|
||||
echo Invalid selection
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
goto start
|
||||
|
||||
endlocal
|
||||
pause
|
||||
|
||||
:ending
|
||||
exit /b
|
||||
87
installer/templates/invoke.sh.in
Normal file
87
installer/templates/invoke.sh.in
Normal file
@@ -0,0 +1,87 @@
|
||||
#!/bin/bash
|
||||
|
||||
# MIT License
|
||||
|
||||
# Coauthored by Lincoln Stein, Eugene Brodsky and Joshua Kimsey
|
||||
# Copyright 2023, The InvokeAI Development Team
|
||||
|
||||
####
|
||||
# This launch script assumes that:
|
||||
# 1. it is located in the runtime directory,
|
||||
# 2. the .venv is also located in the runtime directory and is named exactly that
|
||||
#
|
||||
# If both of the above are not true, this script will likely not work as intended.
|
||||
# Activate the virtual environment and run `invoke.py` directly.
|
||||
####
|
||||
|
||||
set -eu
|
||||
|
||||
# Ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname $(readlink -f "$0"))
|
||||
cd "$scriptdir"
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
export INVOKEAI_ROOT="$scriptdir"
|
||||
|
||||
# Stash the CLI args - when we prompt for user input, `$@` is overwritten
|
||||
PARAMS=$@
|
||||
|
||||
# This setting allows torch to fall back to CPU for operations that are not supported by MPS on macOS.
|
||||
if [ "$(uname -s)" == "Darwin" ]; then
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
fi
|
||||
|
||||
# Primary function for the case statement to determine user input
|
||||
do_choice() {
|
||||
case $1 in
|
||||
1)
|
||||
clear
|
||||
printf "Generate images with a browser-based interface\n"
|
||||
invokeai-web $PARAMS
|
||||
;;
|
||||
2)
|
||||
clear
|
||||
printf "Open the developer console\n"
|
||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||
bash --init-file "$file_name"
|
||||
;;
|
||||
3)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai-web --help
|
||||
;;
|
||||
*)
|
||||
clear
|
||||
printf "Exiting...\n"
|
||||
exit
|
||||
;;
|
||||
esac
|
||||
clear
|
||||
}
|
||||
|
||||
# Command-line interface for launching Invoke functions
|
||||
do_line_input() {
|
||||
clear
|
||||
printf "What would you like to do?\n"
|
||||
printf "1: Generate images using the browser-based interface\n"
|
||||
printf "2: Open the developer console\n"
|
||||
printf "3: Command-line help\n"
|
||||
printf "Q: Quit\n\n"
|
||||
printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest\n\n"
|
||||
read -p "Please enter 1-4, Q: [1] " yn
|
||||
choice=${yn:='1'}
|
||||
do_choice $choice
|
||||
clear
|
||||
}
|
||||
|
||||
# Main IF statement for launching Invoke, and for checking if the user is in the developer console
|
||||
if [ "$0" != "bash" ]; then
|
||||
while true; do
|
||||
do_line_input
|
||||
done
|
||||
else # in developer console
|
||||
python --version
|
||||
printf "Press ^D to exit\n"
|
||||
export PS1="(InvokeAI) \u@\h \w> "
|
||||
fi
|
||||
@@ -10,7 +10,6 @@ from invokeai.app.services.board_images.board_images_default import BoardImagesS
|
||||
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.boards.boards_default import BoardService
|
||||
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.download.download_default import DownloadQueueService
|
||||
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
|
||||
@@ -24,10 +23,6 @@ from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
|
||||
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
|
||||
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlite import (
|
||||
SqliteModelRelationshipRecordStorage,
|
||||
)
|
||||
from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService
|
||||
from invokeai.app.services.names.names_default import SimpleNameService
|
||||
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
||||
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||
@@ -41,15 +36,7 @@ from invokeai.app.services.style_preset_images.style_preset_images_disk import S
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
|
||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
CogView4ConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
FLUXConditioningInfo,
|
||||
SD3ConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
@@ -96,7 +83,6 @@ class ApiDependencies:
|
||||
|
||||
model_images_folder = config.models_path
|
||||
style_presets_folder = config.style_presets_path
|
||||
workflow_thumbnails_folder = config.workflow_thumbnails_path
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
@@ -113,25 +99,10 @@ class ApiDependencies:
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
tensors = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[torch.Tensor](
|
||||
output_folder / "tensors",
|
||||
safe_globals=[torch.Tensor],
|
||||
ephemeral=True,
|
||||
),
|
||||
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True)
|
||||
)
|
||||
conditioning = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[ConditioningFieldData](
|
||||
output_folder / "conditioning",
|
||||
safe_globals=[
|
||||
ConditioningFieldData,
|
||||
BasicConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
FLUXConditioningInfo,
|
||||
SD3ConditioningInfo,
|
||||
CogView4ConditioningInfo,
|
||||
],
|
||||
ephemeral=True,
|
||||
),
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
@@ -141,8 +112,6 @@ class ApiDependencies:
|
||||
download_queue=download_queue_service,
|
||||
events=events,
|
||||
)
|
||||
model_relationships = ModelRelationshipsService()
|
||||
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
|
||||
@@ -151,8 +120,6 @@ class ApiDependencies:
|
||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
|
||||
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
|
||||
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
|
||||
client_state_persistence = ClientStatePersistenceSqlite(db=db)
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
@@ -169,8 +136,6 @@ class ApiDependencies:
|
||||
logger=logger,
|
||||
model_images=model_images_service,
|
||||
model_manager=model_manager,
|
||||
model_relationships=model_relationships,
|
||||
model_relationship_records=model_relationship_records,
|
||||
download_queue=download_queue_service,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
@@ -182,8 +147,6 @@ class ApiDependencies:
|
||||
conditioning=conditioning,
|
||||
style_preset_records=style_preset_records,
|
||||
style_preset_image_files=style_preset_image_files,
|
||||
workflow_thumbnails=workflow_thumbnails,
|
||||
client_state_persistence=client_state_persistence,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutIDValidator
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedMetadata:
|
||||
invokeai_metadata: str | None
|
||||
invokeai_workflow: str | None
|
||||
invokeai_graph: str | None
|
||||
|
||||
|
||||
def extract_metadata_from_image(
|
||||
pil_image: Image.Image,
|
||||
invokeai_metadata_override: str | None,
|
||||
invokeai_workflow_override: str | None,
|
||||
invokeai_graph_override: str | None,
|
||||
logger: logging.Logger,
|
||||
) -> ExtractedMetadata:
|
||||
"""
|
||||
Extracts the "invokeai_metadata", "invokeai_workflow", and "invokeai_graph" data embedded in the PIL Image.
|
||||
|
||||
These items are stored as stringified JSON in the image file's metadata, so we need to do some parsing to validate
|
||||
them. Once parsed, the values are returned as they came (as strings), or None if they are not present or invalid.
|
||||
|
||||
In some situations, we may prefer to override the values extracted from the image file with some other values.
|
||||
|
||||
For example, when uploading an image via API, the client can optionally provide the metadata directly in the request,
|
||||
as opposed to embedding it in the image file. In this case, the client-provided metadata will be used instead of the
|
||||
metadata embedded in the image file.
|
||||
|
||||
Args:
|
||||
pil_image: The PIL Image object.
|
||||
invokeai_metadata_override: The metadata override provided by the client.
|
||||
invokeai_workflow_override: The workflow override provided by the client.
|
||||
invokeai_graph_override: The graph override provided by the client.
|
||||
logger: The logger to use for debug logging.
|
||||
|
||||
Returns:
|
||||
ExtractedMetadata: The extracted metadata, workflow, and graph.
|
||||
"""
|
||||
|
||||
# The fallback value for metadata is None.
|
||||
stringified_metadata: str | None = None
|
||||
|
||||
# Use the metadata override if provided, else attempt to extract it from the image file.
|
||||
metadata_raw = invokeai_metadata_override or pil_image.info.get("invokeai_metadata", None)
|
||||
|
||||
# If the metadata is present in the image file, we will attempt to parse it as JSON. When we create images,
|
||||
# we always store metadata as a stringified JSON dict. So, we expect it to be a string here.
|
||||
if isinstance(metadata_raw, str):
|
||||
try:
|
||||
# Must be a JSON string
|
||||
metadata_parsed = json.loads(metadata_raw)
|
||||
# Must be a dict
|
||||
if isinstance(metadata_parsed, dict):
|
||||
# Looks good, overwrite the fallback value
|
||||
stringified_metadata = metadata_raw
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to parse metadata for uploaded image, {e}")
|
||||
pass
|
||||
|
||||
# We expect the workflow, if embedded in the image, to be a JSON-stringified WorkflowWithoutID. We will store it
|
||||
# as a string.
|
||||
workflow_raw: str | None = invokeai_workflow_override or pil_image.info.get("invokeai_workflow", None)
|
||||
|
||||
# The fallback value for workflow is None.
|
||||
stringified_workflow: str | None = None
|
||||
|
||||
# If the workflow is present in the image file, we will attempt to parse it as JSON. When we create images, we
|
||||
# always store workflows as a stringified JSON WorkflowWithoutID. So, we expect it to be a string here.
|
||||
if isinstance(workflow_raw, str):
|
||||
try:
|
||||
# Validate the workflow JSON before storing it
|
||||
WorkflowWithoutIDValidator.validate_json(workflow_raw)
|
||||
# Looks good, overwrite the fallback value
|
||||
stringified_workflow = workflow_raw
|
||||
except Exception:
|
||||
logger.debug("Failed to parse workflow for uploaded image")
|
||||
pass
|
||||
|
||||
# We expect the workflow, if embedded in the image, to be a JSON-stringified Graph. We will store it as a
|
||||
# string.
|
||||
graph_raw: str | None = invokeai_graph_override or pil_image.info.get("invokeai_graph", None)
|
||||
|
||||
# The fallback value for graph is None.
|
||||
stringified_graph: str | None = None
|
||||
|
||||
# If the graph is present in the image file, we will attempt to parse it as JSON. When we create images, we
|
||||
# always store graphs as a stringified JSON Graph. So, we expect it to be a string here.
|
||||
if isinstance(graph_raw, str):
|
||||
try:
|
||||
# TODO(psyche): Due to pydantic's handling of None values, it is possible for the graph to fail validation,
|
||||
# even if it is a direct dump of a valid graph. Node fields in the graph are allowed to have be unset if
|
||||
# they have incoming connections, but something about the ser/de process cannot adequately handle this.
|
||||
#
|
||||
# In lieu of fixing the graph validation, we will just do a simple check here to see if the graph is dict
|
||||
# with the correct keys. This is not a perfect solution, but it should be good enough for now.
|
||||
|
||||
# FIX ME: Validate the graph JSON before storing it
|
||||
# Graph.model_validate_json(graph_raw)
|
||||
|
||||
# Crappy workaround to validate JSON
|
||||
graph_parsed = json.loads(graph_raw)
|
||||
if not isinstance(graph_parsed, dict):
|
||||
raise ValueError("Not a dict")
|
||||
if not isinstance(graph_parsed.get("nodes", None), dict):
|
||||
raise ValueError("'nodes' is not a dict")
|
||||
if not isinstance(graph_parsed.get("edges", None), list):
|
||||
raise ValueError("'edges' is not a list")
|
||||
|
||||
# Looks good, overwrite the fallback value
|
||||
stringified_graph = graph_raw
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to parse graph for uploaded image, {e}")
|
||||
pass
|
||||
|
||||
return ExtractedMetadata(
|
||||
invokeai_metadata=stringified_metadata, invokeai_workflow=stringified_workflow, invokeai_graph=stringified_graph
|
||||
)
|
||||
@@ -1,7 +1,8 @@
|
||||
import typing
|
||||
from enum import Enum
|
||||
from importlib.metadata import distributions
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from pathlib import Path
|
||||
from platform import python_version
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -11,7 +12,6 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
||||
from invokeai.backend.util.logging import logging
|
||||
@@ -43,6 +43,24 @@ class AppVersion(BaseModel):
|
||||
highlights: Optional[list[str]] = Field(default=None, description="Highlights of release")
|
||||
|
||||
|
||||
class AppDependencyVersions(BaseModel):
|
||||
"""App depencency Versions Response"""
|
||||
|
||||
accelerate: str = Field(description="accelerate version")
|
||||
compel: str = Field(description="compel version")
|
||||
cuda: Optional[str] = Field(description="CUDA version")
|
||||
diffusers: str = Field(description="diffusers version")
|
||||
numpy: str = Field(description="Numpy version")
|
||||
opencv: str = Field(description="OpenCV version")
|
||||
onnx: str = Field(description="ONNX version")
|
||||
pillow: str = Field(description="Pillow (PIL) version")
|
||||
python: str = Field(description="Python version")
|
||||
torch: str = Field(description="PyTorch version")
|
||||
torchvision: str = Field(description="PyTorch Vision version")
|
||||
transformers: str = Field(description="transformers version")
|
||||
xformers: Optional[str] = Field(description="xformers version")
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""App Config Response"""
|
||||
|
||||
@@ -57,23 +75,31 @@ async def get_version() -> AppVersion:
|
||||
return AppVersion(version=__version__)
|
||||
|
||||
|
||||
@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=dict[str, str])
|
||||
async def get_app_deps() -> dict[str, str]:
|
||||
deps: dict[str, str] = {dist.metadata["Name"]: dist.version for dist in distributions()}
|
||||
@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions)
|
||||
async def get_app_deps() -> AppDependencyVersions:
|
||||
try:
|
||||
cuda = torch.version.cuda or "N/A"
|
||||
except Exception:
|
||||
cuda = "N/A"
|
||||
|
||||
deps["CUDA"] = cuda
|
||||
|
||||
sorted_deps = dict(sorted(deps.items(), key=lambda item: item[0].lower()))
|
||||
|
||||
return sorted_deps
|
||||
xformers = version("xformers")
|
||||
except PackageNotFoundError:
|
||||
xformers = None
|
||||
return AppDependencyVersions(
|
||||
accelerate=version("accelerate"),
|
||||
compel=version("compel"),
|
||||
cuda=torch.version.cuda,
|
||||
diffusers=version("diffusers"),
|
||||
numpy=version("numpy"),
|
||||
opencv=version("opencv-python"),
|
||||
onnx=version("onnx"),
|
||||
pillow=version("pillow"),
|
||||
python=python_version(),
|
||||
torch=torch.version.__version__,
|
||||
torchvision=version("torchvision"),
|
||||
transformers=version("transformers"),
|
||||
xformers=xformers,
|
||||
)
|
||||
|
||||
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
async def get_config_() -> AppConfig:
|
||||
async def get_config() -> AppConfig:
|
||||
infill_methods = ["lama", "tile", "cv2", "color"] # TODO: add mosaic back
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append("patchmatch")
|
||||
@@ -95,21 +121,6 @@ async def get_config_() -> AppConfig:
|
||||
)
|
||||
|
||||
|
||||
class InvokeAIAppConfigWithSetFields(BaseModel):
|
||||
"""InvokeAI App Config with model fields set"""
|
||||
|
||||
set_fields: set[str] = Field(description="The set fields")
|
||||
config: InvokeAIAppConfig = Field(description="The InvokeAI App Config")
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields
|
||||
)
|
||||
async def get_runtime_config() -> InvokeAIAppConfigWithSetFields:
|
||||
config = get_config()
|
||||
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/logging",
|
||||
operation_id="get_log_level",
|
||||
|
||||
@@ -1,12 +1,21 @@
|
||||
from fastapi import Body, HTTPException
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult
|
||||
|
||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
|
||||
|
||||
class AddImagesToBoardResult(BaseModel):
|
||||
board_id: str = Field(description="The id of the board the images were added to")
|
||||
added_image_names: list[str] = Field(description="The image names that were added to the board")
|
||||
|
||||
|
||||
class RemoveImagesFromBoardResult(BaseModel):
|
||||
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/",
|
||||
operation_id="add_image_to_board",
|
||||
@@ -14,26 +23,17 @@ board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
201: {"description": "The image was added to a board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=AddImagesToBoardResult,
|
||||
)
|
||||
async def add_image_to_board(
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_name: str = Body(description="The name of the image to add"),
|
||||
) -> AddImagesToBoardResult:
|
||||
):
|
||||
"""Creates a board_image"""
|
||||
try:
|
||||
added_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
added_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
affected_boards.add(old_board_id)
|
||||
|
||||
return AddImagesToBoardResult(
|
||||
added_images=list(added_images),
|
||||
affected_boards=list(affected_boards),
|
||||
result = ApiDependencies.invoker.services.board_images.add_image_to_board(
|
||||
board_id=board_id, image_name=image_name
|
||||
)
|
||||
return result
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to add image to board")
|
||||
|
||||
@@ -45,25 +45,14 @@ async def add_image_to_board(
|
||||
201: {"description": "The image was removed from the board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=RemoveImagesFromBoardResult,
|
||||
)
|
||||
async def remove_image_from_board(
|
||||
image_name: str = Body(description="The name of the image to remove", embed=True),
|
||||
) -> RemoveImagesFromBoardResult:
|
||||
):
|
||||
"""Removes an image from its board, if it had one"""
|
||||
try:
|
||||
removed_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
removed_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
affected_boards.add(old_board_id)
|
||||
return RemoveImagesFromBoardResult(
|
||||
removed_images=list(removed_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
|
||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
return result
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove image from board")
|
||||
|
||||
@@ -83,25 +72,16 @@ async def add_images_to_board(
|
||||
) -> AddImagesToBoardResult:
|
||||
"""Adds a list of images to a board"""
|
||||
try:
|
||||
added_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
added_image_names: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.add_image_to_board(
|
||||
board_id=board_id,
|
||||
image_name=image_name,
|
||||
board_id=board_id, image_name=image_name
|
||||
)
|
||||
added_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
affected_boards.add(old_board_id)
|
||||
|
||||
added_image_names.append(image_name)
|
||||
except Exception:
|
||||
pass
|
||||
return AddImagesToBoardResult(
|
||||
added_images=list(added_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to add images to board")
|
||||
|
||||
@@ -120,20 +100,13 @@ async def remove_images_from_board(
|
||||
) -> RemoveImagesFromBoardResult:
|
||||
"""Removes a list of images from their board, if they had one"""
|
||||
try:
|
||||
removed_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
removed_image_names: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
removed_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
affected_boards.add(old_board_id)
|
||||
removed_image_names.append(image_name)
|
||||
except Exception:
|
||||
pass
|
||||
return RemoveImagesFromBoardResult(
|
||||
removed_images=list(removed_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove images from board")
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
from fastapi import Body, HTTPException
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from invokeai.app.services.videos_common import AddVideosToBoardResult, RemoveVideosFromBoardResult
|
||||
|
||||
board_videos_router = APIRouter(prefix="/v1/board_videos", tags=["boards"])
|
||||
|
||||
|
||||
@board_videos_router.post(
|
||||
"/batch",
|
||||
operation_id="add_videos_to_board",
|
||||
responses={
|
||||
201: {"description": "Videos were added to board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=AddVideosToBoardResult,
|
||||
)
|
||||
async def add_videos_to_board(
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
video_ids: list[str] = Body(description="The ids of the videos to add", embed=True),
|
||||
) -> AddVideosToBoardResult:
|
||||
"""Adds a list of videos to a board"""
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@board_videos_router.post(
|
||||
"/batch/delete",
|
||||
operation_id="remove_videos_from_board",
|
||||
responses={
|
||||
201: {"description": "Videos were removed from board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=RemoveVideosFromBoardResult,
|
||||
)
|
||||
async def remove_videos_from_board(
|
||||
video_ids: list[str] = Body(description="The ids of the videos to remove", embed=True),
|
||||
) -> RemoveVideosFromBoardResult:
|
||||
"""Removes a list of videos from their board, if they had one"""
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
|
||||
@@ -88,9 +87,7 @@ async def delete_board(
|
||||
try:
|
||||
if include_images is True:
|
||||
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
board_id=board_id,
|
||||
categories=None,
|
||||
is_intermediate=None,
|
||||
board_id=board_id
|
||||
)
|
||||
ApiDependencies.invoker.services.images.delete_images_on_board(board_id=board_id)
|
||||
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
||||
@@ -101,9 +98,7 @@ async def delete_board(
|
||||
)
|
||||
else:
|
||||
deleted_board_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
board_id=board_id,
|
||||
categories=None,
|
||||
is_intermediate=None,
|
||||
board_id=board_id
|
||||
)
|
||||
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
||||
return DeleteBoardResult(
|
||||
@@ -146,15 +141,11 @@ async def list_boards(
|
||||
response_model=list[str],
|
||||
)
|
||||
async def list_all_board_image_names(
|
||||
board_id: str = Path(description="The id of the board or 'none' for uncategorized images"),
|
||||
categories: list[ImageCategory] | None = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: bool | None = Query(default=None, description="Whether to list intermediate images."),
|
||||
board_id: str = Path(description="The id of the board"),
|
||||
) -> list[str]:
|
||||
"""Gets a list of images for a board"""
|
||||
|
||||
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
board_id,
|
||||
categories,
|
||||
is_intermediate,
|
||||
)
|
||||
return image_names
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.backend.util.logging import logging
|
||||
|
||||
client_state_router = APIRouter(prefix="/v1/client_state", tags=["client_state"])
|
||||
|
||||
|
||||
@client_state_router.get(
|
||||
"/{queue_id}/get_by_key",
|
||||
operation_id="get_client_state_by_key",
|
||||
response_model=str | None,
|
||||
)
|
||||
async def get_client_state_by_key(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
key: str = Query(..., description="Key to get"),
|
||||
) -> str | None:
|
||||
"""Gets the client state"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(queue_id, key)
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error setting client state")
|
||||
|
||||
|
||||
@client_state_router.post(
|
||||
"/{queue_id}/set_by_key",
|
||||
operation_id="set_client_state",
|
||||
response_model=str,
|
||||
)
|
||||
async def set_client_state(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
key: str = Query(..., description="Key to set"),
|
||||
value: str = Body(..., description="Stringified value to set"),
|
||||
) -> str:
|
||||
"""Sets the client state"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(queue_id, key, value)
|
||||
except Exception as e:
|
||||
logging.error(f"Error setting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error setting client state")
|
||||
|
||||
|
||||
@client_state_router.post(
|
||||
"/{queue_id}/delete",
|
||||
operation_id="delete_client_state",
|
||||
responses={204: {"description": "Client state deleted"}},
|
||||
)
|
||||
async def delete_client_state(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> None:
|
||||
"""Deletes the client state"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.client_state_persistence.delete(queue_id)
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error deleting client state")
|
||||
@@ -1,34 +1,23 @@
|
||||
import io
|
||||
import json
|
||||
import traceback
|
||||
from typing import ClassVar, Optional
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, JsonValue
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_image
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.images.images_common import (
|
||||
DeleteImagesResult,
|
||||
ImageDTO,
|
||||
ImageUrlsDTO,
|
||||
StarredImagesResult,
|
||||
UnstarredImagesResult,
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.util.controlnet_utils import heuristic_resize_fast
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
|
||||
@@ -37,19 +26,6 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
|
||||
|
||||
class ResizeToDimensions(BaseModel):
|
||||
width: int = Field(..., gt=0)
|
||||
height: int = Field(..., gt=0)
|
||||
|
||||
MAX_SIZE: ClassVar[int] = 4096 * 4096
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_total_output_size(self):
|
||||
if self.width * self.height > self.MAX_SIZE:
|
||||
raise ValueError(f"Max total output size for resizing is {self.MAX_SIZE} pixels")
|
||||
return self
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/upload",
|
||||
operation_id="upload_image",
|
||||
@@ -69,58 +45,52 @@ async def upload_image(
|
||||
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
|
||||
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
|
||||
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
||||
resize_to: Optional[str] = Body(
|
||||
default=None,
|
||||
description=f"Dimensions to resize the image to, must be stringified tuple of 2 integers. Max total pixel count: {ResizeToDimensions.MAX_SIZE}",
|
||||
examples=['"[1024,1024]"'],
|
||||
),
|
||||
metadata: Optional[str] = Body(
|
||||
default=None,
|
||||
description="The metadata to associate with the image, must be a stringified JSON dict",
|
||||
embed=True,
|
||||
metadata: Optional[JsonValue] = Body(
|
||||
default=None, description="The metadata to associate with the image", embed=True
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type or not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
_metadata = None
|
||||
_workflow = None
|
||||
_graph = None
|
||||
|
||||
contents = await file.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
if crop_visible:
|
||||
bbox = pil_image.getbbox()
|
||||
pil_image = pil_image.crop(bbox)
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
if crop_visible:
|
||||
try:
|
||||
bbox = pil_image.getbbox()
|
||||
pil_image = pil_image.crop(bbox)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to crop image")
|
||||
# TODO: retain non-invokeai metadata on upload?
|
||||
# attempt to parse metadata from image
|
||||
metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None)
|
||||
if isinstance(metadata_raw, str):
|
||||
_metadata = metadata_raw
|
||||
else:
|
||||
ApiDependencies.invoker.services.logger.debug("Failed to parse metadata for uploaded image")
|
||||
pass
|
||||
|
||||
if resize_to:
|
||||
try:
|
||||
dims = json.loads(resize_to)
|
||||
resize_dims = ResizeToDimensions(**dims)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Invalid resize_to format or size")
|
||||
# attempt to parse workflow from image
|
||||
workflow_raw = pil_image.info.get("invokeai_workflow", None)
|
||||
if isinstance(workflow_raw, str):
|
||||
_workflow = workflow_raw
|
||||
else:
|
||||
ApiDependencies.invoker.services.logger.debug("Failed to parse workflow for uploaded image")
|
||||
pass
|
||||
|
||||
try:
|
||||
# heuristic_resize_fast expects an RGB or RGBA image
|
||||
pil_rgba = pil_image.convert("RGBA")
|
||||
np_image = pil_to_np(pil_rgba)
|
||||
np_image = heuristic_resize_fast(np_image, (resize_dims.width, resize_dims.height))
|
||||
pil_image = np_to_pil(np_image)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to resize image")
|
||||
|
||||
extracted_metadata = extract_metadata_from_image(
|
||||
pil_image=pil_image,
|
||||
invokeai_metadata_override=metadata,
|
||||
invokeai_workflow_override=None,
|
||||
invokeai_graph_override=None,
|
||||
logger=ApiDependencies.invoker.services.logger,
|
||||
)
|
||||
# attempt to extract graph from image
|
||||
graph_raw = pil_image.info.get("invokeai_graph", None)
|
||||
if isinstance(graph_raw, str):
|
||||
_graph = graph_raw
|
||||
else:
|
||||
ApiDependencies.invoker.services.logger.debug("Failed to parse graph for uploaded image")
|
||||
pass
|
||||
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.create(
|
||||
@@ -129,9 +99,9 @@ async def upload_image(
|
||||
image_category=image_category,
|
||||
session_id=session_id,
|
||||
board_id=board_id,
|
||||
metadata=extracted_metadata.invokeai_metadata,
|
||||
workflow=extracted_metadata.invokeai_workflow,
|
||||
graph=extracted_metadata.invokeai_graph,
|
||||
metadata=_metadata,
|
||||
workflow=_workflow,
|
||||
graph=_graph,
|
||||
is_intermediate=is_intermediate,
|
||||
)
|
||||
|
||||
@@ -144,46 +114,18 @@ async def upload_image(
|
||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||
|
||||
|
||||
class ImageUploadEntry(BaseModel):
|
||||
image_dto: ImageDTO = Body(description="The image DTO")
|
||||
presigned_url: str = Body(description="The URL to get the presigned URL for the image upload")
|
||||
|
||||
|
||||
@images_router.post("/", operation_id="create_image_upload_entry")
|
||||
async def create_image_upload_entry(
|
||||
width: int = Body(description="The width of the image"),
|
||||
height: int = Body(description="The height of the image"),
|
||||
board_id: Optional[str] = Body(default=None, description="The board to add this image to, if any"),
|
||||
) -> ImageUploadEntry:
|
||||
"""Uploads an image from a URL, not implemented"""
|
||||
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult)
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> DeleteImagesResult:
|
||||
) -> None:
|
||||
"""Deletes an image"""
|
||||
|
||||
deleted_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
board_id = image_dto.board_id or "none"
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
except Exception:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
return DeleteImagesResult(
|
||||
deleted_images=list(deleted_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
|
||||
|
||||
@images_router.delete("/intermediates", operation_id="clear_intermediates")
|
||||
async def clear_intermediates() -> int:
|
||||
@@ -395,52 +337,23 @@ async def list_image_dtos(
|
||||
return image_dtos
|
||||
|
||||
|
||||
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult)
|
||||
class DeleteImagesFromListResult(BaseModel):
|
||||
deleted_images: list[str]
|
||||
|
||||
|
||||
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
|
||||
async def delete_images_from_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
|
||||
) -> DeleteImagesResult:
|
||||
) -> DeleteImagesFromListResult:
|
||||
try:
|
||||
deleted_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
board_id = image_dto.board_id or "none"
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
except Exception:
|
||||
pass
|
||||
return DeleteImagesResult(
|
||||
deleted_images=list(deleted_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||
|
||||
|
||||
@images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult)
|
||||
async def delete_uncategorized_images() -> DeleteImagesResult:
|
||||
"""Deletes all images that are uncategorized"""
|
||||
|
||||
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
board_id="none", categories=None, is_intermediate=None
|
||||
)
|
||||
|
||||
try:
|
||||
deleted_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
deleted_images: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
deleted_images.append(image_name)
|
||||
except Exception:
|
||||
pass
|
||||
return DeleteImagesResult(
|
||||
deleted_images=list(deleted_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
return DeleteImagesFromListResult(deleted_images=deleted_images)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||
|
||||
@@ -449,50 +362,36 @@ class ImagesUpdatedFromListResult(BaseModel):
|
||||
updated_image_names: list[str] = Field(description="The image names that were updated")
|
||||
|
||||
|
||||
@images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult)
|
||||
@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
|
||||
async def star_images_in_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
|
||||
) -> StarredImagesResult:
|
||||
) -> ImagesUpdatedFromListResult:
|
||||
try:
|
||||
starred_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
updated_image_names: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
updated_image_dto = ApiDependencies.invoker.services.images.update(
|
||||
image_name, changes=ImageRecordChanges(starred=True)
|
||||
)
|
||||
starred_images.add(image_name)
|
||||
affected_boards.add(updated_image_dto.board_id or "none")
|
||||
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
|
||||
updated_image_names.append(image_name)
|
||||
except Exception:
|
||||
pass
|
||||
return StarredImagesResult(
|
||||
starred_images=list(starred_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to star images")
|
||||
|
||||
|
||||
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult)
|
||||
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
|
||||
async def unstar_images_in_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
|
||||
) -> UnstarredImagesResult:
|
||||
) -> ImagesUpdatedFromListResult:
|
||||
try:
|
||||
unstarred_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
updated_image_names: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
updated_image_dto = ApiDependencies.invoker.services.images.update(
|
||||
image_name, changes=ImageRecordChanges(starred=False)
|
||||
)
|
||||
unstarred_images.add(image_name)
|
||||
affected_boards.add(updated_image_dto.board_id or "none")
|
||||
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
|
||||
updated_image_names.append(image_name)
|
||||
except Exception:
|
||||
pass
|
||||
return UnstarredImagesResult(
|
||||
unstarred_images=list(unstarred_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
||||
|
||||
@@ -563,61 +462,3 @@ async def get_bulk_download_item(
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get("/names", operation_id="get_image_names")
|
||||
async def get_image_names(
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None,
|
||||
description="The board id to filter by. Use 'none' to find images without a board.",
|
||||
),
|
||||
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
|
||||
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates"""
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.images.get_image_names(
|
||||
starred_first=starred_first,
|
||||
order_dir=order_dir,
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
)
|
||||
return result
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get image names")
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/images_by_names",
|
||||
operation_id="get_images_by_names",
|
||||
responses={200: {"model": list[ImageDTO]}},
|
||||
)
|
||||
async def get_images_by_names(
|
||||
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
|
||||
) -> list[ImageDTO]:
|
||||
"""Gets image DTOs for the specified image names. Maintains order of input names."""
|
||||
|
||||
try:
|
||||
image_service = ApiDependencies.invoker.services.images
|
||||
|
||||
# Fetch DTOs preserving the order of requested names
|
||||
image_dtos: list[ImageDTO] = []
|
||||
for name in image_names:
|
||||
try:
|
||||
dto = image_service.get_dto(name)
|
||||
image_dtos.append(dto)
|
||||
except Exception:
|
||||
# Skip missing images - they may have been deleted between name fetch and DTO fetch
|
||||
continue
|
||||
|
||||
return image_dtos
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get image DTOs")
|
||||
|
||||
@@ -28,10 +28,12 @@ from invokeai.app.services.model_records import (
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
MainCheckpointConfig,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
@@ -41,7 +43,6 @@ from invokeai.backend.model_manager.starter_models import (
|
||||
STARTER_BUNDLES,
|
||||
STARTER_MODELS,
|
||||
StarterModel,
|
||||
StarterModelBundle,
|
||||
StarterModelWithoutDependencies,
|
||||
)
|
||||
|
||||
@@ -86,7 +87,6 @@ example_model_config = {
|
||||
"config_path": "string",
|
||||
"key": "string",
|
||||
"hash": "string",
|
||||
"file_size": 1,
|
||||
"description": "string",
|
||||
"source": "string",
|
||||
"converted_at": 0,
|
||||
@@ -292,7 +292,7 @@ async def get_hugging_face_models(
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
|
||||
) -> AnyModelConfig:
|
||||
"""Update a model's config."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
@@ -450,7 +450,7 @@ async def install_model(
|
||||
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
|
||||
config: ModelRecordChanges = Body(
|
||||
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
examples=[{"name": "string", "description": "string"}],
|
||||
example={"name": "string", "description": "string"},
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using a string identifier.
|
||||
@@ -800,7 +800,7 @@ async def convert_model(
|
||||
|
||||
class StarterModelResponse(BaseModel):
|
||||
starter_models: list[StarterModel]
|
||||
starter_bundles: dict[str, StarterModelBundle]
|
||||
starter_bundles: dict[str, list[StarterModel]]
|
||||
|
||||
|
||||
def get_is_installed(
|
||||
@@ -834,7 +834,7 @@ async def get_starter_models() -> StarterModelResponse:
|
||||
model.dependencies = missing_deps
|
||||
|
||||
for bundle in starter_bundles.values():
|
||||
for model in bundle.models:
|
||||
for model in bundle:
|
||||
model.is_installed = get_is_installed(model, installed_models)
|
||||
# Remove already-installed dependencies
|
||||
missing_deps: list[StarterModelWithoutDependencies] = []
|
||||
@@ -858,18 +858,6 @@ async def get_stats() -> Optional[CacheStats]:
|
||||
return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/empty_model_cache",
|
||||
operation_id="empty_model_cache",
|
||||
status_code=200,
|
||||
)
|
||||
async def empty_model_cache() -> None:
|
||||
"""Drop all models from the model cache to free RAM/VRAM. 'Locked' models that are in active use will not be dropped."""
|
||||
# Request 1000GB of room in order to force the cache to drop all models.
|
||||
ApiDependencies.invoker.services.logger.info("Emptying model cache.")
|
||||
ApiDependencies.invoker.services.model_manager.load.ram_cache.make_room(1000 * 2**30)
|
||||
|
||||
|
||||
class HFTokenStatus(str, Enum):
|
||||
VALID = "valid"
|
||||
INVALID = "invalid"
|
||||
@@ -894,12 +882,6 @@ class HFTokenHelper:
|
||||
huggingface_hub.login(token=token, add_to_git_credential=False)
|
||||
return cls.get_status()
|
||||
|
||||
@classmethod
|
||||
def reset_token(cls) -> HFTokenStatus:
|
||||
with SuppressOutput(), contextlib.suppress(Exception):
|
||||
huggingface_hub.logout()
|
||||
return cls.get_status()
|
||||
|
||||
|
||||
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
|
||||
async def get_hf_login_status() -> HFTokenStatus:
|
||||
@@ -922,8 +904,3 @@ async def do_hf_login(
|
||||
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
|
||||
|
||||
return token_status
|
||||
|
||||
|
||||
@model_manager_router.delete("/hf_login", operation_id="reset_hf_token", response_model=HFTokenStatus)
|
||||
async def reset_hf_token() -> HFTokenStatus:
|
||||
return HFTokenHelper.reset_token()
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
"""FastAPI route for model relationship records."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Path, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
|
||||
model_relationships_router = APIRouter(prefix="/v1/model_relationships", tags=["model_relationships"])
|
||||
|
||||
# === Schemas ===
|
||||
|
||||
|
||||
class ModelRelationshipCreateRequest(BaseModel):
|
||||
model_key_1: str = Field(
|
||||
...,
|
||||
description="The key of the first model in the relationship",
|
||||
examples=[
|
||||
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
|
||||
"ac32b914-10ab-496e-a24a-3068724b9c35",
|
||||
"d944abfd-c7c3-42e2-a4ff-da640b29b8b4",
|
||||
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
|
||||
"12345678-90ab-cdef-1234-567890abcdef",
|
||||
"fedcba98-7654-3210-fedc-ba9876543210",
|
||||
],
|
||||
)
|
||||
model_key_2: str = Field(
|
||||
...,
|
||||
description="The key of the second model in the relationship",
|
||||
examples=[
|
||||
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
|
||||
"f0c3da4e-d9ff-42b5-a45c-23be75c887c9",
|
||||
"38170dd8-f1e5-431e-866c-2c81f1277fcc",
|
||||
"c57fea2d-7646-424c-b9ad-c0ba60fc68be",
|
||||
"10f7807b-ab54-46a9-ab03-600e88c630a1",
|
||||
"f6c1d267-cf87-4ee0-bee0-37e791eacab7",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ModelRelationshipBatchRequest(BaseModel):
|
||||
model_keys: List[str] = Field(
|
||||
...,
|
||||
description="List of model keys to fetch related models for",
|
||||
examples=[
|
||||
[
|
||||
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
|
||||
"ac32b914-10ab-496e-a24a-3068724b9c35",
|
||||
],
|
||||
[
|
||||
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
|
||||
"12345678-90ab-cdef-1234-567890abcdef",
|
||||
"fedcba98-7654-3210-fedc-ba9876543210",
|
||||
],
|
||||
[
|
||||
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# === Routes ===
|
||||
|
||||
|
||||
@model_relationships_router.get(
|
||||
"/i/{model_key}",
|
||||
operation_id="get_related_models",
|
||||
response_model=list[str],
|
||||
responses={
|
||||
200: {
|
||||
"description": "A list of related model keys was retrieved successfully",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"example": [
|
||||
"15e9eb28-8cfe-47c9-b610-37907a79fc3c",
|
||||
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
|
||||
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
|
||||
]
|
||||
}
|
||||
},
|
||||
},
|
||||
404: {"description": "The specified model could not be found"},
|
||||
422: {"description": "Validation error"},
|
||||
},
|
||||
)
|
||||
async def get_related_models(
|
||||
model_key: str = Path(..., description="The key of the model to get relationships for"),
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get a list of model keys related to a given model.
|
||||
"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.model_relationships.get_related_model_keys(model_key)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@model_relationships_router.post(
|
||||
"/",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
204: {"description": "The relationship was successfully created"},
|
||||
400: {"description": "Invalid model keys or self-referential relationship"},
|
||||
409: {"description": "The relationship already exists"},
|
||||
422: {"description": "Validation error"},
|
||||
500: {"description": "Internal server error"},
|
||||
},
|
||||
summary="Add Model Relationship",
|
||||
description="Creates a **bidirectional** relationship between two models, allowing each to reference the other as related.",
|
||||
)
|
||||
async def add_model_relationship(
|
||||
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to relate"),
|
||||
) -> None:
|
||||
"""
|
||||
Add a relationship between two models.
|
||||
|
||||
Relationships are bidirectional and will be accessible from both models.
|
||||
|
||||
- Raises 400 if keys are invalid or identical.
|
||||
- Raises 409 if the relationship already exists.
|
||||
"""
|
||||
try:
|
||||
if req.model_key_1 == req.model_key_2:
|
||||
raise HTTPException(status_code=400, detail="Cannot relate a model to itself.")
|
||||
|
||||
ApiDependencies.invoker.services.model_relationships.add_model_relationship(
|
||||
req.model_key_1,
|
||||
req.model_key_2,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@model_relationships_router.delete(
|
||||
"/",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
204: {"description": "The relationship was successfully removed"},
|
||||
400: {"description": "Invalid model keys or self-referential relationship"},
|
||||
404: {"description": "The relationship does not exist"},
|
||||
422: {"description": "Validation error"},
|
||||
500: {"description": "Internal server error"},
|
||||
},
|
||||
summary="Remove Model Relationship",
|
||||
description="Removes a **bidirectional** relationship between two models. The relationship must already exist.",
|
||||
)
|
||||
async def remove_model_relationship(
|
||||
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to disconnect"),
|
||||
) -> None:
|
||||
"""
|
||||
Removes a bidirectional relationship between two model keys.
|
||||
|
||||
- Raises 400 if attempting to unlink a model from itself.
|
||||
- Raises 404 if the relationship was not found.
|
||||
"""
|
||||
try:
|
||||
if req.model_key_1 == req.model_key_2:
|
||||
raise HTTPException(status_code=400, detail="Cannot unlink a model from itself.")
|
||||
|
||||
ApiDependencies.invoker.services.model_relationships.remove_model_relationship(
|
||||
req.model_key_1,
|
||||
req.model_key_2,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@model_relationships_router.post(
|
||||
"/batch",
|
||||
operation_id="get_related_models_batch",
|
||||
response_model=List[str],
|
||||
responses={
|
||||
200: {
|
||||
"description": "Related model keys retrieved successfully",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"example": [
|
||||
"ca562b14-995e-4a42-90c1-9528f1a5921d",
|
||||
"cc0c2b8a-c62e-41d6-878e-cc74dde5ca8f",
|
||||
"18ca7649-6a9e-47d5-bc17-41ab1e8cec81",
|
||||
"7c12d1b2-0ef9-4bec-ba55-797b2d8f2ee1",
|
||||
"c382eaa3-0e28-4ab0-9446-408667699aeb",
|
||||
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
|
||||
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
|
||||
]
|
||||
}
|
||||
},
|
||||
},
|
||||
422: {"description": "Validation error"},
|
||||
500: {"description": "Internal server error"},
|
||||
},
|
||||
summary="Get Related Model Keys (Batch)",
|
||||
description="Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering.",
|
||||
)
|
||||
async def get_related_models_batch(
|
||||
req: ModelRelationshipBatchRequest = Body(..., description="Model keys to check for related connections"),
|
||||
) -> list[str]:
|
||||
"""
|
||||
Accepts multiple model keys and returns a flat list of all unique related keys.
|
||||
|
||||
Useful when working with multiple selections in the UI or cross-model comparisons.
|
||||
"""
|
||||
try:
|
||||
all_related: set[str] = set()
|
||||
for key in req.model_keys:
|
||||
related = ApiDependencies.invoker.services.model_relationships.get_related_model_keys(key)
|
||||
all_related.update(related)
|
||||
return list(all_related)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -1,31 +1,26 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelAllExceptCurrentResult,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByDestinationResult,
|
||||
ClearResult,
|
||||
DeleteAllExceptCurrentResult,
|
||||
DeleteByDestinationResult,
|
||||
EnqueueBatchResult,
|
||||
FieldIdentifier,
|
||||
ItemIdsResult,
|
||||
PruneResult,
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
|
||||
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
||||
|
||||
@@ -37,12 +32,6 @@ class SessionQueueAndProcessorStatus(BaseModel):
|
||||
processor: SessionProcessorStatus
|
||||
|
||||
|
||||
class ValidationRunData(BaseModel):
|
||||
workflow_id: str = Field(description="The id of the workflow being published.")
|
||||
input_fields: list[FieldIdentifier] = Body(description="The input fields for the published workflow")
|
||||
output_fields: list[FieldIdentifier] = Body(description="The output fields for the published workflow")
|
||||
|
||||
|
||||
@session_queue_router.post(
|
||||
"/{queue_id}/enqueue_batch",
|
||||
operation_id="enqueue_batch",
|
||||
@@ -54,89 +43,31 @@ async def enqueue_batch(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch: Batch = Body(description="Batch to process"),
|
||||
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||
validation_run_data: Optional[ValidationRunData] = Body(
|
||||
default=None,
|
||||
description="The validation run data to use for this batch. This is only used if this is a validation run.",
|
||||
),
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
try:
|
||||
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
|
||||
queue_id=queue_id, batch=batch, prepend=prepend
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.enqueue_batch(queue_id=queue_id, batch=batch, prepend=prepend)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/list_all",
|
||||
operation_id="list_all_queue_items",
|
||||
"/{queue_id}/list",
|
||||
operation_id="list_queue_items",
|
||||
responses={
|
||||
200: {"model": list[SessionQueueItem]},
|
||||
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
|
||||
},
|
||||
)
|
||||
async def list_all_queue_items(
|
||||
async def list_queue_items(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
|
||||
queue_id=queue_id,
|
||||
destination=destination,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
|
||||
limit: int = Query(default=50, description="The number of items to fetch"),
|
||||
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
|
||||
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
|
||||
priority: int = Query(default=0, description="The pagination cursor priority"),
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
"""Gets all queue items (without graphs)"""
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/item_ids",
|
||||
operation_id="get_queue_item_ids",
|
||||
responses={
|
||||
200: {"model": ItemIdsResult},
|
||||
},
|
||||
)
|
||||
async def get_queue_item_ids(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
|
||||
) -> ItemIdsResult:
|
||||
"""Gets all queue item ids that match the given parameters"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}")
|
||||
|
||||
|
||||
@session_queue_router.post(
|
||||
"/{queue_id}/items_by_ids",
|
||||
operation_id="get_queue_items_by_item_ids",
|
||||
responses={200: {"model": list[SessionQueueItem]}},
|
||||
)
|
||||
async def get_queue_items_by_item_ids(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_ids: list[int] = Body(
|
||||
embed=True, description="Object containing list of queue item ids to fetch queue items for"
|
||||
),
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets queue items for the specified queue item ids. Maintains order of item ids."""
|
||||
try:
|
||||
session_queue_service = ApiDependencies.invoker.services.session_queue
|
||||
|
||||
# Fetch queue items preserving the order of requested item ids
|
||||
queue_items: list[SessionQueueItem] = []
|
||||
for item_id in item_ids:
|
||||
try:
|
||||
queue_item = session_queue_service.get_queue_item(item_id=item_id)
|
||||
if queue_item.queue_id != queue_id: # Auth protection for items from other queues
|
||||
continue
|
||||
queue_items.append(queue_item)
|
||||
except Exception:
|
||||
# Skip missing queue items - they may have been deleted between item id fetch and queue item fetch
|
||||
continue
|
||||
|
||||
return queue_items
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get queue items")
|
||||
return ApiDependencies.invoker.services.session_queue.list_queue_items(
|
||||
queue_id=queue_id, limit=limit, status=status, cursor=cursor, priority=priority
|
||||
)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -148,10 +79,7 @@ async def resume(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Resumes session processor"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_processor.resume()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while resuming queue: {e}")
|
||||
return ApiDependencies.invoker.services.session_processor.resume()
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -163,40 +91,7 @@ async def Pause(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Pauses session processor"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_processor.pause()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while pausing queue: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/cancel_all_except_current",
|
||||
operation_id="cancel_all_except_current",
|
||||
responses={200: {"model": CancelAllExceptCurrentResult}},
|
||||
)
|
||||
async def cancel_all_except_current(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> CancelAllExceptCurrentResult:
|
||||
"""Immediately cancels all queue items except in-processing items"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling all except current: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/delete_all_except_current",
|
||||
operation_id="delete_all_except_current",
|
||||
responses={200: {"model": DeleteAllExceptCurrentResult}},
|
||||
)
|
||||
async def delete_all_except_current(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> DeleteAllExceptCurrentResult:
|
||||
"""Immediately deletes all queue items except in-processing items"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting all except current: {e}")
|
||||
return ApiDependencies.invoker.services.session_processor.pause()
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -209,12 +104,7 @@ async def cancel_by_batch_ids(
|
||||
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
|
||||
) -> CancelByBatchIDsResult:
|
||||
"""Immediately cancels all queue items from the given batch ids"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(
|
||||
queue_id=queue_id, batch_ids=batch_ids
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by batch id: {e}")
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -227,28 +117,9 @@ async def cancel_by_destination(
|
||||
destination: str = Query(description="The destination to cancel all queue items for"),
|
||||
) -> CancelByDestinationResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by destination: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/retry_items_by_id",
|
||||
operation_id="retry_items_by_id",
|
||||
responses={200: {"model": RetryItemsResult}},
|
||||
)
|
||||
async def retry_items_by_id(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_ids: list[int] = Body(description="The queue item ids to retry"),
|
||||
) -> RetryItemsResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while retrying queue items: {e}")
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -262,14 +133,11 @@ async def clear(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> ClearResult:
|
||||
"""Clears the queue entirely, immediately canceling the currently-executing session"""
|
||||
try:
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if queue_item is not None:
|
||||
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
|
||||
return clear_result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while clearing queue: {e}")
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if queue_item is not None:
|
||||
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
|
||||
return clear_result
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -283,10 +151,7 @@ async def prune(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> PruneResult:
|
||||
"""Prunes all completed or errored queue items"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while pruning queue: {e}")
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -300,10 +165,7 @@ async def get_current_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the currently execution queue item"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
|
||||
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -317,10 +179,7 @@ async def get_next_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the next queue item, without executing it"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
|
||||
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -334,12 +193,9 @@ async def get_queue_status(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionQueueAndProcessorStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
try:
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
|
||||
processor = ApiDependencies.invoker.services.session_processor.get_status()
|
||||
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting queue status: {e}")
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
|
||||
processor = ApiDependencies.invoker.services.session_processor.get_status()
|
||||
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -354,10 +210,7 @@ async def get_batch_status(
|
||||
batch_id: str = Path(description="The batch to get the status of"),
|
||||
) -> BatchStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -373,30 +226,7 @@ async def get_queue_item(
|
||||
item_id: int = Path(description="The queue item to get"),
|
||||
) -> SessionQueueItem:
|
||||
"""Gets a queue item"""
|
||||
try:
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id=item_id)
|
||||
if queue_item.queue_id != queue_id:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
return queue_item
|
||||
except SessionQueueItemNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching queue item: {e}")
|
||||
|
||||
|
||||
@session_queue_router.delete(
|
||||
"/{queue_id}/i/{item_id}",
|
||||
operation_id="delete_queue_item",
|
||||
)
|
||||
async def delete_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to delete"),
|
||||
) -> None:
|
||||
"""Deletes a queue item"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting queue item: {e}")
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -411,12 +241,8 @@ async def cancel_queue_item(
|
||||
item_id: int = Path(description="The queue item to cancel"),
|
||||
) -> SessionQueueItem:
|
||||
"""Deletes a queue item"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling queue item: {e}")
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -429,27 +255,6 @@ async def counts_by_destination(
|
||||
destination: str = Query(description="The destination to query"),
|
||||
) -> SessionQueueCountsByDestination:
|
||||
"""Gets the counts of queue items by destination"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")
|
||||
|
||||
|
||||
@session_queue_router.delete(
|
||||
"/{queue_id}/d/{destination}",
|
||||
operation_id="delete_by_destination",
|
||||
responses={200: {"model": DeleteByDestinationResult}},
|
||||
)
|
||||
async def delete_by_destination(
|
||||
queue_id: str = Path(description="The queue id to query"),
|
||||
destination: str = Path(description="The destination to query"),
|
||||
) -> DeleteByDestinationResult:
|
||||
"""Deletes all items with the given destination"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting by destination: {e}")
|
||||
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
|
||||
@@ -25,7 +25,6 @@ async def parse_dynamicprompts(
|
||||
prompt: str = Body(description="The prompt to parse with dynamicprompts"),
|
||||
max_prompts: int = Body(ge=1, le=10000, default=1000, description="The max number of prompts to generate"),
|
||||
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
|
||||
seed: int | None = Body(None, description="The seed to use for random generation. Only used if not combinatorial"),
|
||||
) -> DynamicPromptsResponse:
|
||||
"""Creates a batch process"""
|
||||
max_prompts = min(max_prompts, 10000)
|
||||
@@ -36,7 +35,7 @@ async def parse_dynamicprompts(
|
||||
generator = CombinatorialPromptGenerator()
|
||||
prompts = generator.generate(prompt, max_prompts=max_prompts)
|
||||
else:
|
||||
generator = RandomPromptGenerator(seed=seed)
|
||||
generator = RandomPromptGenerator()
|
||||
prompts = generator.generate(prompt, num_images=max_prompts)
|
||||
except ParseException as e:
|
||||
prompts = [prompt]
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.videos_common import (
|
||||
DeleteVideosResult,
|
||||
StarredVideosResult,
|
||||
UnstarredVideosResult,
|
||||
VideoDTO,
|
||||
VideoIdsResult,
|
||||
VideoRecordChanges,
|
||||
)
|
||||
|
||||
videos_router = APIRouter(prefix="/v1/videos", tags=["videos"])
|
||||
|
||||
|
||||
@videos_router.patch(
|
||||
"/i/{video_id}",
|
||||
operation_id="update_video",
|
||||
response_model=VideoDTO,
|
||||
)
|
||||
async def update_video(
|
||||
video_id: str = Path(description="The id of the video to update"),
|
||||
video_changes: VideoRecordChanges = Body(description="The changes to apply to the video"),
|
||||
) -> VideoDTO:
|
||||
"""Updates a video"""
|
||||
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@videos_router.get(
|
||||
"/i/{video_id}",
|
||||
operation_id="get_video_dto",
|
||||
response_model=VideoDTO,
|
||||
)
|
||||
async def get_video_dto(
|
||||
video_id: str = Path(description="The id of the video to get"),
|
||||
) -> VideoDTO:
|
||||
"""Gets a video's DTO"""
|
||||
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@videos_router.post("/delete", operation_id="delete_videos_from_list", response_model=DeleteVideosResult)
|
||||
async def delete_videos_from_list(
|
||||
video_ids: list[str] = Body(description="The list of ids of videos to delete", embed=True),
|
||||
) -> DeleteVideosResult:
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@videos_router.post("/star", operation_id="star_videos_in_list", response_model=StarredVideosResult)
|
||||
async def star_videos_in_list(
|
||||
video_ids: list[str] = Body(description="The list of ids of videos to star", embed=True),
|
||||
) -> StarredVideosResult:
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@videos_router.post("/unstar", operation_id="unstar_videos_in_list", response_model=UnstarredVideosResult)
|
||||
async def unstar_videos_in_list(
|
||||
video_ids: list[str] = Body(description="The list of ids of videos to unstar", embed=True),
|
||||
) -> UnstarredVideosResult:
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@videos_router.delete("/uncategorized", operation_id="delete_uncategorized_videos", response_model=DeleteVideosResult)
|
||||
async def delete_uncategorized_videos() -> DeleteVideosResult:
|
||||
"""Deletes all videos that are uncategorized"""
|
||||
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@videos_router.get("/", operation_id="list_video_dtos", response_model=OffsetPaginatedResults[VideoDTO])
|
||||
async def list_video_dtos(
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate videos."),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None,
|
||||
description="The board id to filter by. Use 'none' to find videos without a board.",
|
||||
),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of videos per page"),
|
||||
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
|
||||
starred_first: bool = Query(default=True, description="Whether to sort by starred videos first"),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> OffsetPaginatedResults[VideoDTO]:
|
||||
"""Lists video DTOs"""
|
||||
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@videos_router.get("/ids", operation_id="get_video_ids")
|
||||
async def get_video_ids(
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate videos."),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None,
|
||||
description="The board id to filter by. Use 'none' to find videos without a board.",
|
||||
),
|
||||
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
|
||||
starred_first: bool = Query(default=True, description="Whether to sort by starred videos first"),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> VideoIdsResult:
|
||||
"""Gets ordered list of video ids with metadata for optimistic updates"""
|
||||
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@videos_router.post(
|
||||
"/videos_by_ids",
|
||||
operation_id="get_videos_by_ids",
|
||||
responses={200: {"model": list[VideoDTO]}},
|
||||
)
|
||||
async def get_videos_by_ids(
|
||||
video_ids: list[str] = Body(embed=True, description="Object containing list of video ids to fetch DTOs for"),
|
||||
) -> list[VideoDTO]:
|
||||
"""Gets video DTOs for the specified video ids. Maintains order of input ids."""
|
||||
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
@@ -1,10 +1,6 @@
|
||||
import io
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Body, File, HTTPException, Path, Query, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
from fastapi import APIRouter, Body, HTTPException, Path, Query
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
@@ -14,14 +10,11 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
WorkflowCategory,
|
||||
WorkflowNotFoundError,
|
||||
WorkflowRecordDTO,
|
||||
WorkflowRecordListItemWithThumbnailDTO,
|
||||
WorkflowRecordListItemDTO,
|
||||
WorkflowRecordOrderBy,
|
||||
WorkflowRecordWithThumbnailDTO,
|
||||
WorkflowWithoutID,
|
||||
)
|
||||
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_common import WorkflowThumbnailFileNotFoundException
|
||||
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||
|
||||
|
||||
@@ -29,17 +22,15 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||
"/i/{workflow_id}",
|
||||
operation_id="get_workflow",
|
||||
responses={
|
||||
200: {"model": WorkflowRecordWithThumbnailDTO},
|
||||
200: {"model": WorkflowRecordDTO},
|
||||
},
|
||||
)
|
||||
async def get_workflow(
|
||||
workflow_id: str = Path(description="The workflow to get"),
|
||||
) -> WorkflowRecordWithThumbnailDTO:
|
||||
) -> WorkflowRecordDTO:
|
||||
"""Gets a workflow"""
|
||||
try:
|
||||
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
|
||||
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
|
||||
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
@@ -66,11 +57,6 @@ async def delete_workflow(
|
||||
workflow_id: str = Path(description="The workflow to delete"),
|
||||
) -> None:
|
||||
"""Deletes a workflow"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
|
||||
except WorkflowThumbnailFileNotFoundException:
|
||||
# It's OK if the workflow has no thumbnail file. We can still delete the workflow.
|
||||
pass
|
||||
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
|
||||
|
||||
|
||||
@@ -92,7 +78,7 @@ async def create_workflow(
|
||||
"/",
|
||||
operation_id="list_workflows",
|
||||
responses={
|
||||
200: {"model": PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]},
|
||||
200: {"model": PaginatedResults[WorkflowRecordListItemDTO]},
|
||||
},
|
||||
)
|
||||
async def list_workflows(
|
||||
@@ -102,160 +88,10 @@ async def list_workflows(
|
||||
default=WorkflowRecordOrderBy.Name, description="The attribute to order by"
|
||||
),
|
||||
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
|
||||
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories of workflow to get"),
|
||||
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
|
||||
category: WorkflowCategory = Query(default=WorkflowCategory.User, description="The category of workflow to get"),
|
||||
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
is_published: Optional[bool] = Query(default=None, description="Whether to include/exclude published workflows"),
|
||||
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
"""Gets a page of workflows"""
|
||||
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
|
||||
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
|
||||
order_by=order_by,
|
||||
direction=direction,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
query=query,
|
||||
categories=categories,
|
||||
tags=tags,
|
||||
has_been_opened=has_been_opened,
|
||||
is_published=is_published,
|
||||
return ApiDependencies.invoker.services.workflow_records.get_many(
|
||||
order_by=order_by, direction=direction, page=page, per_page=per_page, query=query, category=category
|
||||
)
|
||||
for workflow in workflows.items:
|
||||
workflows_with_thumbnails.append(
|
||||
WorkflowRecordListItemWithThumbnailDTO(
|
||||
thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow.workflow_id),
|
||||
**workflow.model_dump(),
|
||||
)
|
||||
)
|
||||
return PaginatedResults[WorkflowRecordListItemWithThumbnailDTO](
|
||||
items=workflows_with_thumbnails,
|
||||
total=workflows.total,
|
||||
page=workflows.page,
|
||||
pages=workflows.pages,
|
||||
per_page=workflows.per_page,
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.put(
|
||||
"/i/{workflow_id}/thumbnail",
|
||||
operation_id="set_workflow_thumbnail",
|
||||
responses={
|
||||
200: {"model": WorkflowRecordDTO},
|
||||
},
|
||||
)
|
||||
async def set_workflow_thumbnail(
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
image: UploadFile = File(description="The image file to upload"),
|
||||
):
|
||||
"""Sets a workflow's thumbnail image"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_thumbnails.save(workflow_id, pil_image)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@workflows_router.delete(
|
||||
"/i/{workflow_id}/thumbnail",
|
||||
operation_id="delete_workflow_thumbnail",
|
||||
responses={
|
||||
200: {"model": WorkflowRecordDTO},
|
||||
},
|
||||
)
|
||||
async def delete_workflow_thumbnail(
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
):
|
||||
"""Removes a workflow's thumbnail image"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@workflows_router.get(
|
||||
"/i/{workflow_id}/thumbnail",
|
||||
operation_id="get_workflow_thumbnail",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The workflow thumbnail was fetched successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The workflow thumbnail could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def get_workflow_thumbnail(
|
||||
workflow_id: str = Path(description="The id of the workflow thumbnail to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a workflow's thumbnail image"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.workflow_thumbnails.get_path(workflow_id)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=workflow_id + ".png",
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
|
||||
async def get_counts_by_tag(
|
||||
tags: list[str] = Query(description="The tags to get counts for"),
|
||||
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
) -> dict[str, int]:
|
||||
"""Counts workflows by tag"""
|
||||
|
||||
return ApiDependencies.invoker.services.workflow_records.counts_by_tag(
|
||||
tags=tags, categories=categories, has_been_opened=has_been_opened
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.get("/counts_by_category", operation_id="counts_by_category")
|
||||
async def counts_by_category(
|
||||
categories: list[WorkflowCategory] = Query(description="The categories to include"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
) -> dict[str, int]:
|
||||
"""Counts workflows by category"""
|
||||
|
||||
return ApiDependencies.invoker.services.workflow_records.counts_by_category(
|
||||
categories=categories, has_been_opened=has_been_opened
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.put(
|
||||
"/i/{workflow_id}/opened_at",
|
||||
operation_id="update_opened_at",
|
||||
)
|
||||
async def update_opened_at(
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
) -> None:
|
||||
"""Updates the opened_at field of a workflow"""
|
||||
ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id)
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
@@ -11,36 +15,54 @@ from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from torch.backends.mps import is_available as is_mps_available
|
||||
|
||||
# for PyCharm:
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.api.routers import (
|
||||
app_info,
|
||||
board_images,
|
||||
board_videos,
|
||||
boards,
|
||||
client_state,
|
||||
download_queue,
|
||||
images,
|
||||
model_manager,
|
||||
model_relationships,
|
||||
session_queue,
|
||||
style_presets,
|
||||
utilities,
|
||||
videos,
|
||||
workflows,
|
||||
)
|
||||
from invokeai.app.api.sockets import SocketIO
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.custom_openapi import get_openapi_func
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
torch_device_name = TorchDevice.get_torch_device_name()
|
||||
logger.info(f"Using torch device: {torch_device_name}")
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
# We may change the port if the default is in use, this global variable is used to store the port so that we can log
|
||||
# the correct port when the server starts in the lifespan handler.
|
||||
port = app_config.port
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -49,7 +71,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Log the server address when it starts - in case the network log level is not high enough to see the startup log
|
||||
proto = "https" if app_config.ssl_certfile else "http"
|
||||
msg = f"Invoke running on {proto}://{app_config.host}:{app_config.port} (Press CTRL+C to quit)"
|
||||
msg = f"Invoke running on {proto}://{app_config.host}:{port} (Press CTRL+C to quit)"
|
||||
|
||||
# Logging this way ignores the logger's log level and _always_ logs the message
|
||||
record = logger.makeRecord(
|
||||
@@ -127,16 +149,12 @@ app.include_router(utilities.utilities_router, prefix="/api")
|
||||
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||
app.include_router(download_queue.download_queue_router, prefix="/api")
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
app.include_router(videos.videos_router, prefix="/api")
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
app.include_router(board_videos.board_videos_router, prefix="/api")
|
||||
app.include_router(model_relationships.model_relationships_router, prefix="/api")
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
app.include_router(style_presets.style_presets_router, prefix="/api")
|
||||
app.include_router(client_state.client_state_router, prefix="/api")
|
||||
|
||||
app.openapi = get_openapi_func(app)
|
||||
|
||||
@@ -161,16 +179,80 @@ def overridden_redoc() -> HTMLResponse:
|
||||
|
||||
web_root_path = Path(list(web_dir.__path__)[0])
|
||||
|
||||
if app_config.unsafe_disable_picklescan:
|
||||
logger.warning(
|
||||
"The unsafe_disable_picklescan option is enabled. This disables malware scanning while installing and"
|
||||
"loading models, which may allow malicious code to be executed. Use at your own risk."
|
||||
)
|
||||
|
||||
try:
|
||||
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
|
||||
except RuntimeError:
|
||||
logger.warning(f"No UI found at {web_root_path}/dist, skipping UI mount")
|
||||
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
|
||||
app.mount(
|
||||
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
|
||||
) # docs favicon is in here
|
||||
|
||||
|
||||
def check_cudnn(logger: logging.Logger) -> None:
|
||||
"""Check for cuDNN issues that could be causing degraded performance."""
|
||||
if torch.backends.cudnn.is_available():
|
||||
try:
|
||||
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
|
||||
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
|
||||
cudnn_version = torch.backends.cudnn.version()
|
||||
logger.info(f"cuDNN version: {cudnn_version}")
|
||||
except RuntimeError as e:
|
||||
logger.warning(
|
||||
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
|
||||
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
|
||||
f"system. Full error message:\n{e}"
|
||||
)
|
||||
|
||||
|
||||
def invoke_api() -> None:
|
||||
def find_port(port: int) -> int:
|
||||
"""Find a port not in use starting at given port"""
|
||||
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
||||
# https://github.com/WaylonWalker
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(1)
|
||||
if s.connect_ex(("localhost", port)) == 0:
|
||||
return find_port(port=port + 1)
|
||||
else:
|
||||
return port
|
||||
|
||||
if app_config.dev_reload:
|
||||
try:
|
||||
import jurigged
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
|
||||
exc_info=e,
|
||||
)
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||
|
||||
global port
|
||||
port = find_port(app_config.port)
|
||||
if port != app_config.port:
|
||||
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
||||
|
||||
check_cudnn(logger)
|
||||
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=app_config.host,
|
||||
port=port,
|
||||
loop="asyncio",
|
||||
log_level=app_config.log_level_network,
|
||||
ssl_certfile=app_config.ssl_certfile,
|
||||
ssl_keyfile=app_config.ssl_keyfile,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
||||
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
|
||||
uvicorn_logger.handlers.clear()
|
||||
for hdlr in logger.handlers:
|
||||
uvicorn_logger.addHandler(hdlr)
|
||||
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_api()
|
||||
|
||||
@@ -1,5 +1,33 @@
|
||||
import shutil
|
||||
import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
custom_nodes_path = Path(get_config().custom_nodes_path)
|
||||
custom_nodes_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
custom_nodes_init_path = str(custom_nodes_path / "__init__.py")
|
||||
custom_nodes_readme_path = str(custom_nodes_path / "README.md")
|
||||
|
||||
# copy our custom nodes __init__.py to the custom nodes directory
|
||||
shutil.copy(Path(__file__).parent / "custom_nodes/init.py", custom_nodes_init_path)
|
||||
shutil.copy(Path(__file__).parent / "custom_nodes/README.md", custom_nodes_readme_path)
|
||||
|
||||
# set the same permissions as the destination directory, in case our source is read-only,
|
||||
# so that the files are user-writable
|
||||
for p in custom_nodes_path.glob("**/*"):
|
||||
p.chmod(custom_nodes_path.stat().st_mode)
|
||||
|
||||
# Import custom nodes, see https://docs.python.org/3/library/importlib.html#importing-programmatically
|
||||
spec = spec_from_file_location("custom_nodes", custom_nodes_init_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise RuntimeError(f"Could not load custom nodes from {custom_nodes_init_path}")
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# add core nodes to __all__
|
||||
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
|
||||
__all__ = [f.stem for f in python_files] # type: ignore
|
||||
|
||||
@@ -5,12 +5,9 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -22,23 +19,19 @@ from typing import (
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, create_model
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import TypeAliasType
|
||||
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldKind,
|
||||
Input,
|
||||
InputFieldJSONSchemaExtra,
|
||||
UIType,
|
||||
migrate_model_ui_type,
|
||||
)
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@@ -51,6 +44,8 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
CUSTOM_NODE_PACK_SUFFIX = "__invokeai-custom-node"
|
||||
|
||||
|
||||
class InvalidVersionError(ValueError):
|
||||
pass
|
||||
@@ -79,24 +74,13 @@ class Classification(str, Enum, metaclass=MetaEnum):
|
||||
Special = "special"
|
||||
|
||||
|
||||
class Bottleneck(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The bottleneck of an invocation.
|
||||
- `Network`: The invocation's execution is network-bound.
|
||||
- `GPU`: The invocation's execution is GPU-bound.
|
||||
"""
|
||||
|
||||
Network = "network"
|
||||
GPU = "gpu"
|
||||
|
||||
|
||||
class UIConfigBase(BaseModel):
|
||||
"""
|
||||
Provides additional node configuration to the UI.
|
||||
This is used internally by the @invocation decorator logic. Do not use this directly.
|
||||
"""
|
||||
|
||||
tags: Optional[list[str]] = Field(default=None, description="The node's tags")
|
||||
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
|
||||
title: Optional[str] = Field(default=None, description="The node's display name")
|
||||
category: Optional[str] = Field(default=None, description="The node's category")
|
||||
version: str = Field(
|
||||
@@ -111,11 +95,6 @@ class UIConfigBase(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class OriginalModelField(TypedDict):
|
||||
annotation: Any
|
||||
field_info: FieldInfo
|
||||
|
||||
|
||||
class BaseInvocationOutput(BaseModel):
|
||||
"""
|
||||
Base class for all invocation outputs.
|
||||
@@ -123,11 +102,36 @@ class BaseInvocationOutput(BaseModel):
|
||||
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
output_meta: Optional[dict[str, JsonValue]] = Field(
|
||||
default=None,
|
||||
description="Optional dictionary of metadata for the invocation output, unrelated to the invocation's actual output value. This is not exposed as an output field.",
|
||||
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
||||
)
|
||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||
"""Registers an invocation output."""
|
||||
cls._output_classes.add(output)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
||||
"""Gets all invocation outputs."""
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocationOutput = TypeAliasType(
|
||||
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation output types."""
|
||||
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
|
||||
@@ -144,9 +148,6 @@ class BaseInvocationOutput(BaseModel):
|
||||
"""Gets the invocation output's type, as provided by the `@invocation_output` decorator."""
|
||||
return cls.model_fields["type"].default
|
||||
|
||||
_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
|
||||
"""The original model fields, before any modifications were made by the @invocation_output decorator."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
validate_assignment=True,
|
||||
@@ -174,13 +175,68 @@ class BaseInvocation(ABC, BaseModel):
|
||||
All invocations must use the `@invocation` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
"""Gets the invocation's type, as provided by the `@invocation` decorator."""
|
||||
return cls.model_fields["type"].default
|
||||
|
||||
@classmethod
|
||||
def get_output_annotation(cls) -> Type[BaseInvocationOutput]:
|
||||
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
||||
"""Registers an invocation."""
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocation = TypeAliasType(
|
||||
"AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocation)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
def invalidate_typeadapter(cls) -> None:
|
||||
"""Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or
|
||||
denylist is changed, this should be called to ensure the typeadapter is updated and validation respects
|
||||
the updated allowlist and denylist."""
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||
"""Gets all invocations, respecting the allowlist and denylist."""
|
||||
app_config = get_config()
|
||||
allowed_invocations: set[BaseInvocation] = set()
|
||||
for sc in cls._invocation_classes:
|
||||
invocation_type = sc.get_type()
|
||||
is_in_allowlist = (
|
||||
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
||||
)
|
||||
is_in_denylist = (
|
||||
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
|
||||
)
|
||||
if is_in_allowlist and not is_in_denylist:
|
||||
allowed_invocations.add(sc)
|
||||
return allowed_invocations
|
||||
|
||||
@classmethod
|
||||
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
|
||||
"""Gets a map of all invocation types to their invocation classes."""
|
||||
return {i.get_type(): i for i in BaseInvocation.get_invocations()}
|
||||
|
||||
@classmethod
|
||||
def get_invocation_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation types."""
|
||||
return (i.get_type() for i in BaseInvocation.get_invocations())
|
||||
|
||||
@classmethod
|
||||
def get_output_annotation(cls) -> BaseInvocationOutput:
|
||||
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
@@ -212,7 +268,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
Internal invoke method, calls `invoke()` after some prep.
|
||||
Handles optional fields that are required to call `invoke()` and invocation cache.
|
||||
"""
|
||||
for field_name, field in type(self).model_fields.items():
|
||||
for field_name, field in self.model_fields.items():
|
||||
if not field.json_schema_extra or callable(field.json_schema_extra):
|
||||
# something has gone terribly awry, we should always have this and it should be a dict
|
||||
continue
|
||||
@@ -227,9 +283,9 @@ class BaseInvocation(ABC, BaseModel):
|
||||
setattr(self, field_name, orig_default)
|
||||
if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
|
||||
if input_ == Input.Connection:
|
||||
raise RequiredConnectionException(type(self).model_fields["type"].default, field_name)
|
||||
raise RequiredConnectionException(self.model_fields["type"].default, field_name)
|
||||
elif input_ == Input.Any:
|
||||
raise MissingInputException(type(self).model_fields["type"].default, field_name)
|
||||
raise MissingInputException(self.model_fields["type"].default, field_name)
|
||||
|
||||
# skip node cache codepath if it's disabled
|
||||
if services.configuration.node_cache_size == 0:
|
||||
@@ -259,9 +315,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
is_intermediate: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not this is an intermediate invocation.",
|
||||
json_schema_extra=InputFieldJSONSchemaExtra(
|
||||
input=Input.Direct, field_kind=FieldKind.NodeAttribute, ui_type=UIType._IsIntermediate
|
||||
).model_dump(exclude_none=True),
|
||||
json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute},
|
||||
)
|
||||
use_cache: bool = Field(
|
||||
default=True,
|
||||
@@ -269,8 +323,6 @@ class BaseInvocation(ABC, BaseModel):
|
||||
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
||||
)
|
||||
|
||||
bottleneck: ClassVar[Bottleneck]
|
||||
|
||||
UIConfig: ClassVar[UIConfigBase]
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -281,163 +333,21 @@ class BaseInvocation(ABC, BaseModel):
|
||||
coerce_numbers_to_str=True,
|
||||
)
|
||||
|
||||
_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
|
||||
"""The original model fields, before any modifications were made by the @invocation decorator."""
|
||||
|
||||
|
||||
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
|
||||
class InvocationRegistry:
|
||||
_invocation_classes: ClassVar[set[type[BaseInvocation]]] = set()
|
||||
_output_classes: ClassVar[set[type[BaseInvocationOutput]]] = set()
|
||||
|
||||
@classmethod
|
||||
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
|
||||
"""Registers an invocation."""
|
||||
|
||||
invocation_type = invocation.get_type()
|
||||
node_pack = invocation.UIConfig.node_pack
|
||||
|
||||
# Log a warning when an existing invocation is being clobbered by the one we are registering
|
||||
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
|
||||
if clobbered_invocation is not None:
|
||||
# This should always be true - we just checked if the invocation type was in the set
|
||||
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
|
||||
|
||||
if clobbered_node_pack == "invokeai":
|
||||
# The invocation being clobbered is a core invocation
|
||||
logger.warning(f'Overriding core node "{invocation_type}" with node from "{node_pack}"')
|
||||
else:
|
||||
# The invocation being clobbered is a custom invocation
|
||||
logger.warning(
|
||||
f'Overriding node "{invocation_type}" from "{node_pack}" with node from "{clobbered_node_pack}"'
|
||||
)
|
||||
cls._invocation_classes.remove(clobbered_invocation)
|
||||
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls.invalidate_invocation_typeadapter()
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def get_invocation_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantic TypeAdapter for the union of all invocation types.
|
||||
|
||||
This is used to parse serialized invocations into the correct invocation class.
|
||||
|
||||
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
|
||||
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
|
||||
the updated allowlist and denylist.
|
||||
|
||||
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
|
||||
"""
|
||||
return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")])
|
||||
|
||||
@classmethod
|
||||
def invalidate_invocation_typeadapter(cls) -> None:
|
||||
"""Invalidates the cached invocation type adapter."""
|
||||
cls.get_invocation_typeadapter.cache_clear()
|
||||
|
||||
@classmethod
|
||||
def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]:
|
||||
"""Gets all invocations, respecting the allowlist and denylist."""
|
||||
app_config = get_config()
|
||||
allowed_invocations: set[type[BaseInvocation]] = set()
|
||||
for sc in cls._invocation_classes:
|
||||
invocation_type = sc.get_type()
|
||||
is_in_allowlist = (
|
||||
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
||||
)
|
||||
is_in_denylist = (
|
||||
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
|
||||
)
|
||||
if is_in_allowlist and not is_in_denylist:
|
||||
allowed_invocations.add(sc)
|
||||
return allowed_invocations
|
||||
|
||||
@classmethod
|
||||
def get_invocations_map(cls) -> dict[str, type[BaseInvocation]]:
|
||||
"""Gets a map of all invocation types to their invocation classes."""
|
||||
return {i.get_type(): i for i in cls.get_invocation_classes()}
|
||||
|
||||
@classmethod
|
||||
def get_invocation_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation types."""
|
||||
return (i.get_type() for i in cls.get_invocation_classes())
|
||||
|
||||
@classmethod
|
||||
def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] | None:
|
||||
"""Gets the invocation class for a given invocation type."""
|
||||
return cls.get_invocations_map().get(invocation_type)
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
|
||||
"""Registers an invocation output."""
|
||||
output_type = output.get_type()
|
||||
|
||||
# Log a warning when an existing invocation is being clobbered by the one we are registering
|
||||
clobbered_output = InvocationRegistry.get_output_for_type(output_type)
|
||||
if clobbered_output is not None:
|
||||
# TODO(psyche): We do not record the node pack of the output, so we cannot log it here
|
||||
logger.warning(f'Overriding invocation output "{output_type}"')
|
||||
cls._output_classes.remove(clobbered_output)
|
||||
|
||||
cls._output_classes.add(output)
|
||||
cls.invalidate_output_typeadapter()
|
||||
|
||||
@classmethod
|
||||
def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
|
||||
"""Gets all invocation outputs."""
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
def get_outputs_map(cls) -> dict[str, type[BaseInvocationOutput]]:
|
||||
"""Gets a map of all output types to their output classes."""
|
||||
return {i.get_type(): i for i in cls.get_output_classes()}
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantic TypeAdapter for the union of all invocation output types.
|
||||
|
||||
This is used to parse serialized invocation outputs into the correct invocation output class.
|
||||
|
||||
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
|
||||
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
|
||||
the updated allowlist and denylist.
|
||||
|
||||
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
|
||||
"""
|
||||
return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")])
|
||||
|
||||
@classmethod
|
||||
def invalidate_output_typeadapter(cls) -> None:
|
||||
"""Invalidates the cached invocation output type adapter."""
|
||||
cls.get_output_typeadapter.cache_clear()
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation output types."""
|
||||
return (i.get_type() for i in cls.get_output_classes())
|
||||
|
||||
@classmethod
|
||||
def get_output_for_type(cls, output_type: str) -> type[BaseInvocationOutput] | None:
|
||||
"""Gets the output class for a given output type."""
|
||||
return cls.get_outputs_map().get(output_type)
|
||||
|
||||
|
||||
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
|
||||
"id",
|
||||
"is_intermediate",
|
||||
"use_cache",
|
||||
"type",
|
||||
"workflow",
|
||||
"bottleneck",
|
||||
}
|
||||
|
||||
RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"}
|
||||
|
||||
RESERVED_OUTPUT_FIELD_NAMES = {"type", "output_meta"}
|
||||
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
|
||||
|
||||
|
||||
class _Model(BaseModel):
|
||||
@@ -450,15 +360,6 @@ with warnings.catch_warnings():
|
||||
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
|
||||
|
||||
|
||||
def is_enum_member(value: Any, enum_class: type[Enum]) -> bool:
|
||||
"""Checks if a value is a member of an enum class."""
|
||||
try:
|
||||
enum_class(value)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
||||
"""
|
||||
Validates the fields of an invocation or invocation output:
|
||||
@@ -470,144 +371,54 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
|
||||
"""
|
||||
for name, field in model_fields.items():
|
||||
if name in RESERVED_PYDANTIC_FIELD_NAMES:
|
||||
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved by pydantic)")
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
|
||||
|
||||
if not field.annotation:
|
||||
raise InvalidFieldError(f"{model_type}.{name}: Invalid field type (missing annotation)")
|
||||
raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')
|
||||
|
||||
if not isinstance(field.json_schema_extra, dict):
|
||||
raise InvalidFieldError(f"{model_type}.{name}: Invalid field definition (missing json_schema_extra dict)")
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)'
|
||||
)
|
||||
|
||||
field_kind = field.json_schema_extra.get("field_kind", None)
|
||||
|
||||
# must have a field_kind
|
||||
if not is_enum_member(field_kind, FieldKind):
|
||||
if not isinstance(field_kind, FieldKind):
|
||||
raise InvalidFieldError(
|
||||
f"{model_type}.{name}: Invalid field definition for (maybe it's not an InputField or OutputField?)"
|
||||
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
|
||||
)
|
||||
|
||||
if field_kind == FieldKind.Input.value and (
|
||||
if field_kind is FieldKind.Input and (
|
||||
name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES
|
||||
):
|
||||
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved input field name)")
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
|
||||
|
||||
if field_kind == FieldKind.Output.value and name in RESERVED_OUTPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved output field name)")
|
||||
if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
|
||||
|
||||
if field_kind == FieldKind.Internal.value and name not in RESERVED_INPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (internal field without reserved name)")
|
||||
if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
|
||||
)
|
||||
|
||||
# node attribute fields *must* be in the reserved list
|
||||
if (
|
||||
field_kind == FieldKind.NodeAttribute.value
|
||||
field_kind is FieldKind.NodeAttribute
|
||||
and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES
|
||||
and name not in RESERVED_OUTPUT_FIELD_NAMES
|
||||
):
|
||||
raise InvalidFieldError(
|
||||
f"{model_type}.{name}: Invalid field name (node attribute field without reserved name)"
|
||||
f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)'
|
||||
)
|
||||
|
||||
ui_type = field.json_schema_extra.get("ui_type", None)
|
||||
ui_model_base = field.json_schema_extra.get("ui_model_base", None)
|
||||
ui_model_type = field.json_schema_extra.get("ui_model_type", None)
|
||||
ui_model_variant = field.json_schema_extra.get("ui_model_variant", None)
|
||||
ui_model_format = field.json_schema_extra.get("ui_model_format", None)
|
||||
|
||||
if ui_type is not None:
|
||||
# There are 3 cases where we may need to take action:
|
||||
#
|
||||
# 1. The ui_type is a migratable, deprecated value. For example, ui_type=UIType.MainModel value is
|
||||
# deprecated and should be migrated to:
|
||||
# - ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]
|
||||
# - ui_model_type=[ModelType.Main]
|
||||
#
|
||||
# 2. ui_type was set in conjunction with any of the new ui_model_[base|type|variant|format] fields, which
|
||||
# is not allowed (they are mutually exclusive). In this case, we ignore ui_type and log a warning.
|
||||
#
|
||||
# 3. ui_type is a deprecated value that is not migratable. For example, ui_type=UIType.Image is deprecated;
|
||||
# Image fields are now automatically detected based on the field's type annotation. In this case, we
|
||||
# ignore ui_type and log a warning.
|
||||
#
|
||||
# The cases must be checked in this order to ensure proper handling.
|
||||
|
||||
# Easier to work with as an enum
|
||||
ui_type = UIType(ui_type)
|
||||
|
||||
# The enum member values are not always the same as their names - we want to log the name so the user can
|
||||
# easily review their code and see where the deprecated enum member is used.
|
||||
human_readable_name = f"UIType.{ui_type.name}"
|
||||
|
||||
# Case 1: migratable deprecated value
|
||||
did_migrate = migrate_model_ui_type(ui_type, field.json_schema_extra)
|
||||
|
||||
if did_migrate:
|
||||
logger.warning(
|
||||
f'{model_type}.{name}: Migrated deprecated "ui_type" "{human_readable_name}" to new ui_model_[base|type|variant|format] fields'
|
||||
)
|
||||
field.json_schema_extra.pop("ui_type")
|
||||
|
||||
# Case 2: mutually exclusive with new fields
|
||||
elif (
|
||||
ui_model_base is not None
|
||||
or ui_model_type is not None
|
||||
or ui_model_variant is not None
|
||||
or ui_model_format is not None
|
||||
):
|
||||
logger.warning(
|
||||
f'{model_type}.{name}: "ui_type" is mutually exclusive with "ui_model_[base|type|format|variant]", ignoring "ui_type"'
|
||||
)
|
||||
field.json_schema_extra.pop("ui_type")
|
||||
|
||||
# Case 3: deprecated value that is not migratable
|
||||
elif ui_type.startswith("DEPRECATED_"):
|
||||
logger.warning(f'{model_type}.{name}: Deprecated "ui_type" "{human_readable_name}", ignoring')
|
||||
field.json_schema_extra.pop("ui_type")
|
||||
|
||||
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
|
||||
logger.warn(f"\"UIType.{ui_type.split('_')[-1]}\" is deprecated, ignoring")
|
||||
field.json_schema_extra.pop("ui_type")
|
||||
return None
|
||||
|
||||
|
||||
class NoDefaultSentinel:
|
||||
pass
|
||||
|
||||
|
||||
def validate_field_default(
|
||||
cls_name: str, field_name: str, invocation_type: str, annotation: Any, field_info: FieldInfo
|
||||
) -> None:
|
||||
"""Validates the default value of a field against its pydantic field definition."""
|
||||
|
||||
assert isinstance(field_info.json_schema_extra, dict), "json_schema_extra is not a dict"
|
||||
|
||||
# By the time we are doing this, we've already done some pydantic magic by overriding the original default value.
|
||||
# We store the original default value in the json_schema_extra dict, so we can validate it here.
|
||||
orig_default = field_info.json_schema_extra.get("orig_default", NoDefaultSentinel)
|
||||
|
||||
if orig_default is NoDefaultSentinel:
|
||||
return
|
||||
|
||||
# To validate the default value, we can create a temporary pydantic model with the field we are validating as its
|
||||
# only field. Then validate the default value against this temporary model.
|
||||
TempDefaultValidator = cast(BaseModel, create_model(cls_name, **{field_name: (annotation, field_info)}))
|
||||
|
||||
try:
|
||||
TempDefaultValidator.model_validate({field_name: orig_default})
|
||||
except Exception as e:
|
||||
raise InvalidFieldError(
|
||||
f'Default value for field "{field_name}" on invocation "{invocation_type}" is invalid, {e}'
|
||||
) from e
|
||||
|
||||
|
||||
def is_optional(annotation: Any) -> bool:
|
||||
"""
|
||||
Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None).
|
||||
"""
|
||||
origin = typing.get_origin(annotation)
|
||||
# PEP 604 unions (int|None) have origin types.UnionType
|
||||
is_union = origin is typing.Union or origin is types.UnionType
|
||||
if not is_union:
|
||||
return False
|
||||
return any(arg is type(None) for arg in typing.get_args(annotation))
|
||||
|
||||
|
||||
def invocation(
|
||||
invocation_type: str,
|
||||
title: Optional[str] = None,
|
||||
@@ -616,7 +427,6 @@ def invocation(
|
||||
version: Optional[str] = None,
|
||||
use_cache: Optional[bool] = True,
|
||||
classification: Classification = Classification.Stable,
|
||||
bottleneck: Bottleneck = Bottleneck.GPU,
|
||||
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
||||
"""
|
||||
Registers an invocation.
|
||||
@@ -628,7 +438,6 @@ def invocation(
|
||||
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
|
||||
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
|
||||
:param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
|
||||
:param Bottleneck bottleneck: The bottleneck of the invocation. Defaults to Bottleneck.GPU. Use Network if the invocation is network-bound.
|
||||
"""
|
||||
|
||||
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
|
||||
@@ -637,38 +446,19 @@ def invocation(
|
||||
if re.compile(r"^\S+$").match(invocation_type) is None:
|
||||
raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')
|
||||
|
||||
# The node pack is the module name - will be "invokeai" for built-in nodes
|
||||
node_pack = cls.__module__.split(".")[0]
|
||||
if invocation_type in BaseInvocation.get_invocation_types():
|
||||
raise ValueError(f'Invocation type "{invocation_type}" already exists')
|
||||
|
||||
validate_fields(cls.model_fields, invocation_type)
|
||||
|
||||
fields: dict[str, tuple[Any, FieldInfo]] = {}
|
||||
|
||||
original_model_fields: dict[str, OriginalModelField] = {}
|
||||
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
annotation = field_info.annotation
|
||||
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
|
||||
assert isinstance(field_info.json_schema_extra, dict), (
|
||||
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||
)
|
||||
|
||||
original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
|
||||
|
||||
validate_field_default(cls.__name__, field_name, invocation_type, annotation, field_info)
|
||||
|
||||
if field_info.default is None and not is_optional(annotation):
|
||||
annotation = annotation | None
|
||||
|
||||
fields[field_name] = (annotation, field_info)
|
||||
|
||||
# Add OpenAPI schema extras
|
||||
uiconfig: dict[str, Any] = {}
|
||||
uiconfig["title"] = title
|
||||
uiconfig["tags"] = tags
|
||||
uiconfig["category"] = category
|
||||
uiconfig["classification"] = classification
|
||||
uiconfig["node_pack"] = node_pack
|
||||
# The node pack is the module name - will be "invokeai" for built-in nodes
|
||||
uiconfig["node_pack"] = cls.__module__.split(".")[0]
|
||||
|
||||
if version is not None:
|
||||
try:
|
||||
@@ -677,7 +467,7 @@ def invocation(
|
||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||
uiconfig["version"] = version
|
||||
else:
|
||||
logger.warning(f'No version specified for node "{invocation_type}", using "1.0.0"')
|
||||
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
|
||||
uiconfig["version"] = "1.0.0"
|
||||
|
||||
cls.UIConfig = UIConfigBase(**uiconfig)
|
||||
@@ -685,8 +475,6 @@ def invocation(
|
||||
if use_cache is not None:
|
||||
cls.model_fields["use_cache"].default = use_cache
|
||||
|
||||
cls.bottleneck = bottleneck
|
||||
|
||||
# Add the invocation type to the model.
|
||||
|
||||
# You'd be tempted to just add the type field and rebuild the model, like this:
|
||||
@@ -696,27 +484,11 @@ def invocation(
|
||||
# Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
|
||||
# not work. Instead, we have to create a new class with the type field and patch the original class with it.
|
||||
|
||||
invocation_type_annotation = Literal[invocation_type]
|
||||
|
||||
# Field() returns an instance of FieldInfo, but thanks to a pydantic implementation detail, it is _typed_ as Any.
|
||||
# This cast makes the type annotation match the class's true type.
|
||||
invocation_type_field_info = cast(
|
||||
FieldInfo,
|
||||
Field(title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}),
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
invocation_type_field = Field(
|
||||
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
|
||||
)
|
||||
|
||||
fields["type"] = (invocation_type_annotation, invocation_type_field_info)
|
||||
|
||||
# Invocation outputs must be registered using the @invocation_output decorator, but it is possible that the
|
||||
# output is registered _after_ this invocation is registered. It depends on module import ordering.
|
||||
#
|
||||
# We can only confirm the output for an invocation is registered after all modules are imported. There's
|
||||
# only really one good time to do that - during application startup, in `run_app.py`, after loading all
|
||||
# custom nodes.
|
||||
#
|
||||
# We can still do some basic validation here - ensure the invoke method is defined and returns an instance
|
||||
# of BaseInvocationOutput.
|
||||
|
||||
# Validate the `invoke()` method is implemented
|
||||
if "invoke" in cls.__abstractmethods__:
|
||||
raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method')
|
||||
@@ -738,13 +510,18 @@ def invocation(
|
||||
)
|
||||
|
||||
docstring = cls.__doc__
|
||||
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields) # type: ignore
|
||||
new_class.__doc__ = docstring
|
||||
new_class._original_model_fields = original_model_fields
|
||||
cls = create_model(
|
||||
cls.__qualname__,
|
||||
__base__=cls,
|
||||
__module__=cls.__module__,
|
||||
type=(invocation_type_annotation, invocation_type_field),
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
|
||||
InvocationRegistry.register_invocation(new_class)
|
||||
# TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
|
||||
BaseInvocation.register_invocation(cls) # type: ignore
|
||||
|
||||
return new_class
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -767,41 +544,29 @@ def invocation_output(
|
||||
if re.compile(r"^\S+$").match(output_type) is None:
|
||||
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
|
||||
|
||||
if output_type in BaseInvocationOutput.get_output_types():
|
||||
raise ValueError(f'Invocation type "{output_type}" already exists')
|
||||
|
||||
validate_fields(cls.model_fields, output_type)
|
||||
|
||||
fields: dict[str, tuple[Any, FieldInfo]] = {}
|
||||
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
annotation = field_info.annotation
|
||||
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
|
||||
assert isinstance(field_info.json_schema_extra, dict), (
|
||||
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||
)
|
||||
|
||||
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
|
||||
|
||||
if field_info.default is not PydanticUndefined and is_optional(annotation):
|
||||
annotation = annotation | None
|
||||
fields[field_name] = (annotation, field_info)
|
||||
|
||||
# Add the output type to the model.
|
||||
output_type_annotation = Literal[output_type]
|
||||
|
||||
# Field() returns an instance of FieldInfo, but thanks to a pydantic implementation detail, it is _typed_ as Any.
|
||||
# This cast makes the type annotation match the class's true type.
|
||||
output_type_field_info = cast(
|
||||
FieldInfo,
|
||||
Field(title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}),
|
||||
output_type_annotation = Literal[output_type] # type: ignore
|
||||
output_type_field = Field(
|
||||
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
|
||||
)
|
||||
|
||||
fields["type"] = (output_type_annotation, output_type_field_info)
|
||||
|
||||
docstring = cls.__doc__
|
||||
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
|
||||
new_class.__doc__ = docstring
|
||||
cls = create_model(
|
||||
cls.__qualname__,
|
||||
__base__=cls,
|
||||
__module__=cls.__module__,
|
||||
type=(output_type_annotation, output_type_field),
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
|
||||
InvocationRegistry.register_output(new_class)
|
||||
BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly?
|
||||
|
||||
return new_class
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -1,270 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import (
|
||||
FloatOutput,
|
||||
ImageOutput,
|
||||
IntegerOutput,
|
||||
StringOutput,
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
BATCH_GROUP_IDS = Literal[
|
||||
"None",
|
||||
"Group 1",
|
||||
"Group 2",
|
||||
"Group 3",
|
||||
"Group 4",
|
||||
"Group 5",
|
||||
]
|
||||
|
||||
|
||||
class NotExecutableNodeError(Exception):
|
||||
def __init__(self, message: str = "This class should never be executed or instantiated directly."):
|
||||
super().__init__(message)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BaseBatchInvocation(BaseInvocation):
|
||||
batch_group_id: BATCH_GROUP_IDS = InputField(
|
||||
default="None",
|
||||
description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
|
||||
input=Input.Direct,
|
||||
title="Batch Group",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_batch",
|
||||
title="Image Batch",
|
||||
tags=["primitives", "image", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class ImageBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
|
||||
|
||||
images: list[ImageField] = InputField(
|
||||
min_length=1,
|
||||
description="The images to batch over",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation_output("image_generator_output")
|
||||
class ImageGeneratorOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of boards"""
|
||||
|
||||
images: list[ImageField] = OutputField(description="The generated images")
|
||||
|
||||
|
||||
class ImageGeneratorField(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_generator",
|
||||
title="Image Generator",
|
||||
tags=["primitives", "board", "image", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class ImageGenerator(BaseInvocation):
|
||||
"""Generated a collection of images for use in a batched generation"""
|
||||
|
||||
generator: ImageGeneratorField = InputField(
|
||||
description="The image generator.",
|
||||
input=Input.Direct,
|
||||
title="Generator Type",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageGeneratorOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"string_batch",
|
||||
title="String Batch",
|
||||
tags=["primitives", "string", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class StringBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
|
||||
|
||||
strings: list[str] = InputField(
|
||||
min_length=1,
|
||||
description="The strings to batch over",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation_output("string_generator_output")
|
||||
class StringGeneratorOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of strings"""
|
||||
|
||||
strings: list[str] = OutputField(description="The generated strings")
|
||||
|
||||
|
||||
class StringGeneratorField(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
@invocation(
|
||||
"string_generator",
|
||||
title="String Generator",
|
||||
tags=["primitives", "string", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class StringGenerator(BaseInvocation):
|
||||
"""Generated a range of strings for use in a batched generation"""
|
||||
|
||||
generator: StringGeneratorField = InputField(
|
||||
description="The string generator.",
|
||||
input=Input.Direct,
|
||||
title="Generator Type",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringGeneratorOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"integer_batch",
|
||||
title="Integer Batch",
|
||||
tags=["primitives", "integer", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class IntegerBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""
|
||||
|
||||
integers: list[int] = InputField(
|
||||
min_length=1,
|
||||
description="The integers to batch over",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation_output("integer_generator_output")
|
||||
class IntegerGeneratorOutput(BaseInvocationOutput):
|
||||
integers: list[int] = OutputField(description="The generated integers")
|
||||
|
||||
|
||||
class IntegerGeneratorField(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
@invocation(
|
||||
"integer_generator",
|
||||
title="Integer Generator",
|
||||
tags=["primitives", "int", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class IntegerGenerator(BaseInvocation):
|
||||
"""Generated a range of integers for use in a batched generation"""
|
||||
|
||||
generator: IntegerGeneratorField = InputField(
|
||||
description="The integer generator.",
|
||||
input=Input.Direct,
|
||||
title="Generator Type",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerGeneratorOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"float_batch",
|
||||
title="Float Batch",
|
||||
tags=["primitives", "float", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class FloatBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each float in the batch."""
|
||||
|
||||
floats: list[float] = InputField(
|
||||
min_length=1,
|
||||
description="The floats to batch over",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation_output("float_generator_output")
|
||||
class FloatGeneratorOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of floats"""
|
||||
|
||||
floats: list[float] = OutputField(description="The generated floats")
|
||||
|
||||
|
||||
class FloatGeneratorField(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
@invocation(
|
||||
"float_generator",
|
||||
title="Float Generator",
|
||||
tags=["primitives", "float", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class FloatGenerator(BaseInvocation):
|
||||
"""Generated a range of floats for use in a batched generation"""
|
||||
|
||||
generator: FloatGeneratorField = InputField(
|
||||
description="The float generator.",
|
||||
input=Input.Direct,
|
||||
title="Generator Type",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatGeneratorOutput:
|
||||
raise NotExecutableNodeError()
|
||||
@@ -1,363 +0,0 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from diffusers.models.transformers.transformer_cogview4 import CogView4Transformer2DModel
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
CogView4ConditioningField,
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import CogView4ConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"cogview4_denoise",
|
||||
title="Denoise - CogView4",
|
||||
tags=["image", "cogview4"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run the denoising process with a CogView4 model."""
|
||||
|
||||
# If latents is provided, this means we are doing image-to-image.
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None, description=FieldDescriptions.latents, input=Input.Connection
|
||||
)
|
||||
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
|
||||
)
|
||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.cogview4_model, input=Input.Connection, title="Transformer"
|
||||
)
|
||||
positive_conditioning: CogView4ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_conditioning: CogView4ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
cfg_scale: float | list[float] = InputField(default=3.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
width: int = InputField(default=1024, multiple_of=32, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=32, description="Height of the generated image.")
|
||||
steps: int = InputField(default=25, gt=0, description=FieldDescriptions.steps)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
latents = latents.detach().to("cpu")
|
||||
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Prepare the inpaint mask.
|
||||
- Loads the mask
|
||||
- Resizes if necessary
|
||||
- Casts to same device/dtype as latents
|
||||
|
||||
Args:
|
||||
context (InvocationContext): The invocation context, for loading the inpaint mask.
|
||||
latents (torch.Tensor): A latent image tensor. Used to determine the target shape, device, and dtype for the
|
||||
inpaint mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor | None: Inpaint mask. Values of 0.0 represent the regions to be fully denoised, and 1.0
|
||||
represent the regions to be preserved.
|
||||
"""
|
||||
if self.denoise_mask is None:
|
||||
return None
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
|
||||
# The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
|
||||
# 1.0 represents the regions to be preserved.
|
||||
# We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
|
||||
mask = 1.0 - mask
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
mask = tv_resize(
|
||||
img=mask,
|
||||
size=[latent_height, latent_width],
|
||||
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
||||
antialias=False,
|
||||
)
|
||||
|
||||
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
||||
return mask
|
||||
|
||||
def _load_text_conditioning(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
conditioning_name: str,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
cogview4_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(cogview4_conditioning, CogView4ConditioningInfo)
|
||||
cogview4_conditioning = cogview4_conditioning.to(dtype=dtype, device=device)
|
||||
|
||||
return cogview4_conditioning.glm_embeds
|
||||
|
||||
def _get_noise(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
seed: int,
|
||||
) -> torch.Tensor:
|
||||
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
|
||||
rand_device = "cpu"
|
||||
rand_dtype = torch.float16
|
||||
|
||||
return torch.randn(
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // LATENT_SCALE_FACTOR,
|
||||
int(width) // LATENT_SCALE_FACTOR,
|
||||
device=rand_device,
|
||||
dtype=rand_dtype,
|
||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
|
||||
"""Prepare the CFG scale list.
|
||||
|
||||
Args:
|
||||
num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
|
||||
on the scheduler used (e.g. higher order schedulers).
|
||||
|
||||
Returns:
|
||||
list[float]: _description_
|
||||
"""
|
||||
if isinstance(self.cfg_scale, float):
|
||||
cfg_scale = [self.cfg_scale] * num_timesteps
|
||||
elif isinstance(self.cfg_scale, list):
|
||||
assert len(self.cfg_scale) == num_timesteps
|
||||
cfg_scale = self.cfg_scale
|
||||
else:
|
||||
raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
|
||||
|
||||
return cfg_scale
|
||||
|
||||
def _convert_timesteps_to_sigmas(self, image_seq_len: int, timesteps: torch.Tensor) -> list[float]:
|
||||
# The logic to prepare the timestep / sigma schedule is based on:
|
||||
# https://github.com/huggingface/diffusers/blob/b38450d5d2e5b87d5ff7088ee5798c85587b9635/src/diffusers/pipelines/cogview4/pipeline_cogview4.py#L575-L595
|
||||
# The default FlowMatchEulerDiscreteScheduler configs are based on:
|
||||
# https://huggingface.co/THUDM/CogView4-6B/blob/fb6f57289c73ac6d139e8d81bd5a4602d1877847/scheduler/scheduler_config.json
|
||||
# This implementation differs slightly from the original for the sake of simplicity (differs in terminal value
|
||||
# handling, not quantizing timesteps to integers, etc.).
|
||||
|
||||
def calculate_timestep_shift(
|
||||
image_seq_len: int, base_seq_len: int = 256, base_shift: float = 0.25, max_shift: float = 0.75
|
||||
) -> float:
|
||||
m = (image_seq_len / base_seq_len) ** 0.5
|
||||
mu = m * max_shift + base_shift
|
||||
return mu
|
||||
|
||||
def time_shift_linear(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
|
||||
return mu / (mu + (1 / t - 1) ** sigma)
|
||||
|
||||
mu = calculate_timestep_shift(image_seq_len)
|
||||
sigmas = time_shift_linear(mu, 1.0, timesteps)
|
||||
return sigmas.tolist()
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
):
|
||||
inference_dtype = torch.bfloat16
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
assert isinstance(transformer_info.model, CogView4Transformer2DModel)
|
||||
|
||||
# Load/process the conditioning data.
|
||||
# TODO(ryand): Make CFG optional.
|
||||
do_classifier_free_guidance = True
|
||||
pos_prompt_embeds = self._load_text_conditioning(
|
||||
context=context,
|
||||
conditioning_name=self.positive_conditioning.conditioning_name,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
)
|
||||
neg_prompt_embeds = self._load_text_conditioning(
|
||||
context=context,
|
||||
conditioning_name=self.negative_conditioning.conditioning_name,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Prepare misc. conditioning variables.
|
||||
# TODO(ryand): We could expose these as params (like with SDXL). But, we should experiment to see if they are
|
||||
# useful first.
|
||||
original_size = torch.tensor([(self.height, self.width)], dtype=pos_prompt_embeds.dtype, device=device)
|
||||
target_size = torch.tensor([(self.height, self.width)], dtype=pos_prompt_embeds.dtype, device=device)
|
||||
crops_coords_top_left = torch.tensor([(0, 0)], dtype=pos_prompt_embeds.dtype, device=device)
|
||||
|
||||
# Prepare the timestep / sigma schedule.
|
||||
patch_size = transformer_info.model.config.patch_size # type: ignore
|
||||
assert isinstance(patch_size, int)
|
||||
image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (patch_size**2)
|
||||
# We add an extra step to the end to account for the final timestep of 0.0.
|
||||
timesteps: list[float] = torch.linspace(1, 0, self.steps + 1).tolist()
|
||||
# Clip the timesteps schedule based on denoising_start and denoising_end.
|
||||
timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
|
||||
sigmas = self._convert_timesteps_to_sigmas(image_seq_len, torch.tensor(timesteps))
|
||||
total_steps = len(timesteps) - 1
|
||||
|
||||
# Prepare the CFG scale list.
|
||||
cfg_scale = self._prepare_cfg_scale(total_steps)
|
||||
|
||||
# Load the input latents, if provided.
|
||||
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
|
||||
if init_latents is not None:
|
||||
init_latents = init_latents.to(device=device, dtype=inference_dtype)
|
||||
|
||||
# Generate initial latent noise.
|
||||
num_channels_latents = transformer_info.model.config.in_channels # type: ignore
|
||||
assert isinstance(num_channels_latents, int)
|
||||
noise = self._get_noise(
|
||||
batch_size=1,
|
||||
num_channels_latents=num_channels_latents,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
# Prepare input latent image.
|
||||
if init_latents is not None:
|
||||
# Noise the init_latents by the appropriate amount for the first timestep.
|
||||
s_0 = sigmas[0]
|
||||
latents = s_0 * noise + (1.0 - s_0) * init_latents
|
||||
else:
|
||||
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
|
||||
if self.denoising_start > 1e-5:
|
||||
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
|
||||
latents = noise
|
||||
|
||||
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
|
||||
# denoising steps.
|
||||
if len(timesteps) <= 1:
|
||||
return latents
|
||||
|
||||
# Prepare inpaint extension.
|
||||
inpaint_mask = self._prep_inpaint_mask(context, latents)
|
||||
inpaint_extension: RectifiedFlowInpaintExtension | None = None
|
||||
if inpaint_mask is not None:
|
||||
assert init_latents is not None
|
||||
inpaint_extension = RectifiedFlowInpaintExtension(
|
||||
init_latents=init_latents,
|
||||
inpaint_mask=inpaint_mask,
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
step_callback = self._build_step_callback(context)
|
||||
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=0,
|
||||
order=1,
|
||||
total_steps=total_steps,
|
||||
timestep=int(timesteps[0]),
|
||||
latents=latents,
|
||||
),
|
||||
)
|
||||
|
||||
with transformer_info.model_on_device() as (_, transformer):
|
||||
assert isinstance(transformer, CogView4Transformer2DModel)
|
||||
|
||||
# Denoising loop
|
||||
for step_idx in tqdm(range(total_steps)):
|
||||
t_curr = timesteps[step_idx]
|
||||
sigma_curr = sigmas[step_idx]
|
||||
sigma_prev = sigmas[step_idx + 1]
|
||||
|
||||
# Expand the timestep to match the latent model input.
|
||||
# Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
|
||||
timestep = torch.tensor([t_curr * 1000], device=device).expand(latents.shape[0])
|
||||
|
||||
# TODO(ryand): Support both sequential and batched CFG inference.
|
||||
noise_pred_cond = transformer(
|
||||
hidden_states=latents,
|
||||
encoder_hidden_states=pos_prompt_embeds,
|
||||
timestep=timestep,
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# Apply CFG.
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond = transformer(
|
||||
hidden_states=latents,
|
||||
encoder_hidden_states=neg_prompt_embeds,
|
||||
timestep=timestep,
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
|
||||
# Compute the previous noisy sample x_t -> x_t-1.
|
||||
latents_dtype = latents.dtype
|
||||
# TODO(ryand): Is casting to float32 necessary for precision/stability? I copied this from SD3.
|
||||
latents = latents.to(dtype=torch.float32)
|
||||
latents = latents + (sigma_prev - sigma_curr) * noise_pred
|
||||
latents = latents.to(dtype=latents_dtype)
|
||||
|
||||
if inpaint_extension is not None:
|
||||
latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, sigma_prev)
|
||||
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=step_idx + 1,
|
||||
order=1,
|
||||
total_steps=total_steps,
|
||||
timestep=int(t_curr),
|
||||
latents=latents,
|
||||
),
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, BaseModelType.CogView4)
|
||||
|
||||
return step_callback
|
||||
@@ -1,76 +0,0 @@
|
||||
import einops
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4
|
||||
|
||||
# TODO(ryand): This is effectively a copy of SD3ImageToLatentsInvocation and a subset of ImageToLatentsInvocation. We
|
||||
# should refactor to avoid this duplication.
|
||||
|
||||
|
||||
@invocation(
|
||||
"cogview4_i2l",
|
||||
title="Image to Latents - CogView4",
|
||||
tags=["image", "latents", "vae", "i2l", "cogview4"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates latents from an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to encode.")
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
assert isinstance(vae_info.model, AutoencoderKL)
|
||||
estimated_working_memory = estimate_vae_working_memory_cogview4(
|
||||
operation="encode", image_tensor=image_tensor, vae=vae_info.model
|
||||
)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
|
||||
vae.disable_tiling()
|
||||
|
||||
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
# TODO: Use seed to make sampling reproducible.
|
||||
latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype)
|
||||
|
||||
latents = vae.config.scaling_factor * latents
|
||||
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, AutoencoderKL)
|
||||
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
@@ -1,79 +0,0 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4
|
||||
|
||||
# TODO(ryand): This is effectively a copy of SD3LatentsToImageInvocation and a subset of LatentsToImageInvocation. We
|
||||
# should refactor to avoid this duplication.
|
||||
|
||||
|
||||
@invocation(
|
||||
"cogview4_l2i",
|
||||
title="Latents to Image - CogView4",
|
||||
tags=["latents", "image", "vae", "l2i", "cogview4"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class CogView4LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL))
|
||||
estimated_working_memory = estimate_vae_working_memory_cogview4(
|
||||
operation="decode", image_tensor=latents, vae=vae_info.model
|
||||
)
|
||||
with (
|
||||
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
|
||||
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
|
||||
):
|
||||
context.util.signal_progress("Running VAE")
|
||||
assert isinstance(vae, (AutoencoderKL))
|
||||
latents = latents.to(TorchDevice.choose_torch_device())
|
||||
|
||||
vae.disable_tiling()
|
||||
|
||||
tiling_context = nullcontext()
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
with torch.inference_mode(), tiling_context:
|
||||
# copied from diffusers pipeline
|
||||
latents = latents / vae.config.scaling_factor
|
||||
img = vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
image_dto = context.images.save(image=img_pil)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -1,57 +0,0 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import (
|
||||
GlmEncoderField,
|
||||
ModelIdentifierField,
|
||||
TransformerField,
|
||||
VAEField,
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation_output("cogview4_model_loader_output")
|
||||
class CogView4ModelLoaderOutput(BaseInvocationOutput):
|
||||
"""CogView4 base model loader output."""
|
||||
|
||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||
glm_encoder: GlmEncoderField = OutputField(description=FieldDescriptions.glm_encoder, title="GLM Encoder")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation(
|
||||
"cogview4_model_loader",
|
||||
title="Main Model - CogView4",
|
||||
tags=["model", "cogview4"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class CogView4ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a CogView4 base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.cogview4_model,
|
||||
input=Input.Direct,
|
||||
ui_model_base=BaseModelType.CogView4,
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CogView4ModelLoaderOutput:
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
glm_tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
glm_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
return CogView4ModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
glm_encoder=GlmEncoderField(tokenizer=glm_tokenizer, text_encoder=glm_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
)
|
||||
@@ -1,92 +0,0 @@
|
||||
import torch
|
||||
from transformers import GlmModel, PreTrainedTokenizerFast
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
|
||||
from invokeai.app.invocations.model import GlmEncoderField
|
||||
from invokeai.app.invocations.primitives import CogView4ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
CogView4ConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
)
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
# The CogView4 GLM Text Encoder max sequence length set based on the default in diffusers.
|
||||
COGVIEW4_GLM_MAX_SEQ_LEN = 1024
|
||||
|
||||
|
||||
@invocation(
|
||||
"cogview4_text_encoder",
|
||||
title="Prompt - CogView4",
|
||||
tags=["prompt", "conditioning", "cogview4"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class CogView4TextEncoderInvocation(BaseInvocation):
|
||||
"""Encodes and preps a prompt for a cogview4 image."""
|
||||
|
||||
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
|
||||
glm_encoder: GlmEncoderField = InputField(
|
||||
title="GLM Encoder",
|
||||
description=FieldDescriptions.glm_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CogView4ConditioningOutput:
|
||||
glm_embeds = self._glm_encode(context, max_seq_len=COGVIEW4_GLM_MAX_SEQ_LEN)
|
||||
conditioning_data = ConditioningFieldData(conditionings=[CogView4ConditioningInfo(glm_embeds=glm_embeds)])
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return CogView4ConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _glm_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
|
||||
prompt = [self.prompt]
|
||||
|
||||
# TODO(ryand): Add model inputs to the invocation rather than hard-coding.
|
||||
with (
|
||||
context.models.load(self.glm_encoder.text_encoder).model_on_device() as (_, glm_text_encoder),
|
||||
context.models.load(self.glm_encoder.tokenizer).model_on_device() as (_, glm_tokenizer),
|
||||
):
|
||||
context.util.signal_progress("Running GLM text encoder")
|
||||
assert isinstance(glm_text_encoder, GlmModel)
|
||||
assert isinstance(glm_tokenizer, PreTrainedTokenizerFast)
|
||||
|
||||
text_inputs = glm_tokenizer(
|
||||
prompt,
|
||||
padding="longest",
|
||||
max_length=max_seq_len,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = glm_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
assert isinstance(text_input_ids, torch.Tensor)
|
||||
assert isinstance(untruncated_ids, torch.Tensor)
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = glm_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
|
||||
context.logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_seq_len} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
current_length = text_input_ids.shape[1]
|
||||
pad_length = (16 - (current_length % 16)) % 16
|
||||
if pad_length > 0:
|
||||
pad_ids = torch.full(
|
||||
(text_input_ids.shape[0], pad_length),
|
||||
fill_value=glm_tokenizer.pad_token_id,
|
||||
dtype=text_input_ids.dtype,
|
||||
device=text_input_ids.device,
|
||||
)
|
||||
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
|
||||
prompt_embeds = glm_text_encoder(
|
||||
text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Iterator, List, Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType, SplitLongTextMode
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
@@ -40,10 +40,10 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@invocation(
|
||||
"compel",
|
||||
title="Prompt - SD1.5",
|
||||
title="Prompt",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -104,7 +104,6 @@ class CompelInvocation(BaseInvocation):
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
split_long_text_mode=SplitLongTextMode.SENTENCES,
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
@@ -114,13 +113,6 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
del compel
|
||||
del patched_tokenizer
|
||||
del tokenizer
|
||||
del ti_manager
|
||||
del text_encoder
|
||||
del text_encoder_info
|
||||
|
||||
c = c.detach().to("cpu")
|
||||
|
||||
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
|
||||
@@ -213,7 +205,6 @@ class SDXLPromptInvocationBase:
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
requires_pooled=get_pooled,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
split_long_text_mode=SplitLongTextMode.SENTENCES,
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(prompt)
|
||||
@@ -229,10 +220,7 @@ class SDXLPromptInvocationBase:
|
||||
else:
|
||||
c_pooled = None
|
||||
|
||||
del compel
|
||||
del patched_tokenizer
|
||||
del tokenizer
|
||||
del ti_manager
|
||||
del text_encoder
|
||||
del text_encoder_info
|
||||
|
||||
@@ -245,10 +233,10 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
@invocation(
|
||||
"sdxl_compel_prompt",
|
||||
title="Prompt - SDXL",
|
||||
title="SDXL Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -339,10 +327,10 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
|
||||
@invocation(
|
||||
"sdxl_refiner_compel_prompt",
|
||||
title="Prompt - SDXL Refiner",
|
||||
title="SDXL Refiner Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.1.2",
|
||||
version="1.1.1",
|
||||
)
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -388,10 +376,10 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"clip_skip",
|
||||
title="Apply CLIP Skip - SD1.5, SDXL",
|
||||
title="CLIP Skip",
|
||||
tags=["clipskip", "clip", "skip"],
|
||||
category="conditioning",
|
||||
version="1.1.1",
|
||||
version="1.1.0",
|
||||
)
|
||||
class CLIPSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
@@ -525,7 +513,7 @@ def log_tokenization_for_text(
|
||||
usedTokens += 1
|
||||
|
||||
if usedTokens > 0:
|
||||
print(f"\n>> [TOKENLOG] Tokens {display_label or ''} ({usedTokens}):")
|
||||
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
print(f"{tokenized}\x1b[0m")
|
||||
|
||||
if discarded != "":
|
||||
|
||||
@@ -274,12 +274,12 @@ class InvokeAdjustImageHuePlusInvocation(BaseInvocation, WithMetadata, WithBoard
|
||||
title="Enhance Image",
|
||||
tags=["enhance", "image"],
|
||||
category="image",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class InvokeImageEnhanceInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Applies processing from PIL's ImageEnhance module. Originally created by @dwringer"""
|
||||
|
||||
image: ImageField = InputField(description="The image for which to apply processing")
|
||||
image: ImageField = InputField(default=None, description="The image for which to apply processing")
|
||||
invert: bool = InputField(default=False, description="Whether to invert the image colors")
|
||||
color: float = InputField(ge=0, default=1.0, description="Color enhancement factor")
|
||||
contrast: float = InputField(ge=0, default=1.0, description="Contrast enhancement factor")
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
# Invocations for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import (
|
||||
CONTROLNET_MODE_VALUES,
|
||||
CONTROLNET_RESIZE_VALUES,
|
||||
heuristic_resize_fast,
|
||||
)
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
|
||||
@invocation_output("control_output")
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation(
|
||||
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3"
|
||||
)
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model,
|
||||
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, BaseModelType.StableDiffusionXL],
|
||||
ui_model_type=ModelType.ControlNet,
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"heuristic_resize",
|
||||
title="Heuristic Resize",
|
||||
tags=["image, controlnet"],
|
||||
category="image",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class HeuristicResizeInvocation(BaseInvocation):
|
||||
"""Resize an image using a heuristic method. Preserves edge maps."""
|
||||
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
|
||||
height: int = InputField(default=512, ge=1, description="The height to resize to (px)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
np_img = pil_to_np(image)
|
||||
np_resized = heuristic_resize_fast(np_img, (self.width, self.height))
|
||||
resized = np_to_pil(np_resized)
|
||||
image_dto = context.images.save(image=resized)
|
||||
return ImageOutput.build(image_dto)
|
||||
716
invokeai/app/invocations/controlnet_image_processors.py
Normal file
716
invokeai/app/invocations/controlnet_image_processors.py
Normal file
@@ -0,0 +1,716 @@
|
||||
# Invocations for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import bool, float
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from controlnet_aux import (
|
||||
ContentShuffleDetector,
|
||||
LeresDetector,
|
||||
MediapipeFaceDetector,
|
||||
MidasDetector,
|
||||
MLSDdetector,
|
||||
NormalBaeDetector,
|
||||
PidiNetDetector,
|
||||
SamDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
|
||||
@invocation_output("control_output")
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# superclass just passes through image without processing
|
||||
return image
|
||||
|
||||
def load_image(self, context: InvocationContext) -> Image.Image:
|
||||
# allows override for any special formatting specific to the preprocessor
|
||||
return context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
self._context = context
|
||||
raw_image = self.load_image(context)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
|
||||
# currently can't see processed image in node UI without a showImage node,
|
||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||
image_dto = context.images.save(image=processed_image)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(image_name=image_dto.image_name)
|
||||
return ImageOutput(
|
||||
image=processed_image_field,
|
||||
# width=processed_image.width,
|
||||
width=image_dto.width,
|
||||
# height=processed_image.height,
|
||||
height=image_dto.height,
|
||||
# mode=processed_image.mode,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"canny_image_processor",
|
||||
title="Canny Processor",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.3.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
low_threshold: int = InputField(
|
||||
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
high_threshold: int = InputField(
|
||||
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
|
||||
def load_image(self, context: InvocationContext) -> Image.Image:
|
||||
# Keep alpha channel for Canny processing to detect edges of transparent areas
|
||||
return context.images.get_pil(self.image.image_name, "RGBA")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
processed_image = get_canny_edges(
|
||||
image,
|
||||
self.low_threshold,
|
||||
self.high_threshold,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"hed_image_processor",
|
||||
title="HED (softedge) Processor",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
hed_processor = HEDProcessor()
|
||||
processed_image = hed_processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_image_processor",
|
||||
title="Lineart Processor",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
lineart_processor = LineartProcessor()
|
||||
processed_image = lineart_processor.run(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_anime_image_processor",
|
||||
title="Lineart Anime Processor",
|
||||
tags=["controlnet", "lineart", "anime"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
processor = LineartAnimeProcessor()
|
||||
processed_image = processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"midas_depth_image_processor",
|
||||
title="Midas Depth Processor",
|
||||
tags=["controlnet", "midas"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
|
||||
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(
|
||||
image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
image_resolution=self.image_resolution,
|
||||
detect_resolution=self.detect_resolution,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"normalbae_image_processor",
|
||||
title="Normal BAE Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"mlsd_image_processor",
|
||||
title="MLSD Processor",
|
||||
tags=["controlnet", "mlsd"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
thr_v=self.thr_v,
|
||||
thr_d=self.thr_d,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"pidi_image_processor",
|
||||
title="PIDI Processor",
|
||||
tags=["controlnet", "pidi"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"content_shuffle_image_processor",
|
||||
title="Content Shuffle Processor",
|
||||
tags=["controlnet", "contentshuffle"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
h: int = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
h=self.h,
|
||||
w=self.w,
|
||||
f=self.f,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||
@invocation(
|
||||
"zoe_depth_image_processor",
|
||||
title="Zoe (Depth) Processor",
|
||||
tags=["controlnet", "zoe", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = zoe_depth_processor(image)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"mediapipe_face_processor",
|
||||
title="Mediapipe Face Processor",
|
||||
tags=["controlnet", "mediapipe", "face"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
|
||||
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(
|
||||
image,
|
||||
max_faces=self.max_faces,
|
||||
min_confidence=self.min_confidence,
|
||||
image_resolution=self.image_resolution,
|
||||
detect_resolution=self.detect_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"leres_image_processor",
|
||||
title="Leres (Depth) Processor",
|
||||
tags=["controlnet", "leres", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
|
||||
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
|
||||
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
|
||||
boost: bool = InputField(default=False, description="Whether to use boost mode")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = leres_processor(
|
||||
image,
|
||||
thr_a=self.thr_a,
|
||||
thr_b=self.thr_b,
|
||||
boost=self.boost,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"tile_image_processor",
|
||||
title="Tile Resample Processor",
|
||||
tags=["controlnet", "tile"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
|
||||
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
||||
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
||||
|
||||
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
||||
def tile_resample(
|
||||
self,
|
||||
np_img: np.ndarray,
|
||||
res=512, # never used?
|
||||
down_sampling_rate=1.0,
|
||||
):
|
||||
np_img = HWC3(np_img)
|
||||
if down_sampling_rate < 1.1:
|
||||
return np_img
|
||||
H, W, C = np_img.shape
|
||||
H = int(float(H) / float(down_sampling_rate))
|
||||
W = int(float(W) / float(down_sampling_rate))
|
||||
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
|
||||
return np_img
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_np_image = self.tile_resample(
|
||||
np_img,
|
||||
# res=self.tile_size,
|
||||
down_sampling_rate=self.down_sampling_rate,
|
||||
)
|
||||
processed_image = Image.fromarray(processed_np_image)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"segment_anything_processor",
|
||||
title="Segment Anything Processor",
|
||||
tags=["controlnet", "segmentanything"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
||||
)
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_image = segment_anything_processor(
|
||||
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class SamDetectorReproducibleColors(SamDetector):
|
||||
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
|
||||
# base class show_anns() method randomizes colors,
|
||||
# which seems to also lead to non-reproducible image generation
|
||||
# so using ADE20k color palette instead
|
||||
def show_anns(self, anns: List[Dict]):
|
||||
if len(anns) == 0:
|
||||
return
|
||||
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
|
||||
h, w = anns[0]["segmentation"].shape
|
||||
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
|
||||
palette = ade_palette()
|
||||
for i, ann in enumerate(sorted_anns):
|
||||
m = ann["segmentation"]
|
||||
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
|
||||
# doing modulo just in case number of annotated regions exceeds number of colors in palette
|
||||
ann_color = palette[i % len(palette)]
|
||||
img[:, :] = ann_color
|
||||
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
|
||||
return np.array(final_img, dtype=np.uint8)
|
||||
|
||||
|
||||
@invocation(
|
||||
"color_map_image_processor",
|
||||
title="Color Map Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a color map from the provided image"""
|
||||
|
||||
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
height, width = np_image.shape[:2]
|
||||
|
||||
width_tile_size = min(self.color_map_tile_size, width)
|
||||
height_tile_size = min(self.color_map_tile_size, height)
|
||||
|
||||
color_map = cv2.resize(
|
||||
np_image,
|
||||
(width // width_tile_size, height // height_tile_size),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
color_map = Image.fromarray(color_map)
|
||||
return color_map
|
||||
|
||||
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
|
||||
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "LiheYoung/depth-anything-large-hf",
|
||||
"base": "LiheYoung/depth-anything-base-hf",
|
||||
"small": "LiheYoung/depth-anything-small-hf",
|
||||
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
"depth_anything_image_processor",
|
||||
title="Depth Anything Processor",
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
version="1.1.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a depth map based on the Depth Anything algorithm"""
|
||||
|
||||
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
|
||||
default="small_v2", description="The size of the depth model to use"
|
||||
)
|
||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def load_depth_anything(model_path: Path):
|
||||
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
|
||||
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
|
||||
return DepthAnythingPipeline(depth_anything_pipeline)
|
||||
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
|
||||
) as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
depth_map = depth_anything_detector.generate_depth(image)
|
||||
|
||||
# Resizing to user target specified size
|
||||
new_height = int(image.size[1] * (self.resolution / image.size[0]))
|
||||
depth_map = depth_map.resize((self.resolution, new_height))
|
||||
|
||||
return depth_map
|
||||
|
||||
|
||||
@invocation(
|
||||
"dw_openpose_image_processor",
|
||||
title="DW Openpose Image Processor",
|
||||
tags=["controlnet", "dwpose", "openpose"],
|
||||
category="controlnet",
|
||||
version="1.1.1",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates an openpose pose from an image using DWPose"""
|
||||
|
||||
draw_body: bool = InputField(default=True)
|
||||
draw_face: bool = InputField(default=False)
|
||||
draw_hands: bool = InputField(default=False)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
|
||||
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
||||
|
||||
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
processed_image = dw_openpose(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
draw_hands=self.draw_hands,
|
||||
draw_body=self.draw_body,
|
||||
resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"heuristic_resize",
|
||||
title="Heuristic Resize",
|
||||
tags=["image, controlnet"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class HeuristicResizeInvocation(BaseInvocation):
|
||||
"""Resize an image using a heuristic method. Preserves edge maps."""
|
||||
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
|
||||
height: int = InputField(default=512, ge=1, description="The height to resize to (px)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
np_img = pil_to_np(image)
|
||||
np_resized = heuristic_resize(np_img, (self.width, self.height))
|
||||
resized = np_to_pil(np_resized)
|
||||
image_dto = context.images.save(image=resized)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -1,14 +1,12 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageFilter
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
@@ -21,8 +19,7 @@ from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
||||
from invokeai.app.invocations.model import UNetField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.model_manager.config import MainConfigBase
|
||||
from invokeai.backend.model_manager.taxonomy import ModelVariantType
|
||||
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
|
||||
|
||||
@@ -44,13 +41,15 @@ class GradientMaskOutput(BaseInvocationOutput):
|
||||
title="Create Gradient Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.3.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class CreateGradientMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising."""
|
||||
"""Creates mask for denoising model run."""
|
||||
|
||||
mask: ImageField = InputField(description="Image which will be masked", ui_order=1)
|
||||
edge_radius: int = InputField(default=16, ge=0, description="How far to expand the edges of the mask", ui_order=2)
|
||||
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||
edge_radius: int = InputField(
|
||||
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
|
||||
)
|
||||
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
|
||||
minimum_denoise: float = InputField(
|
||||
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
|
||||
@@ -81,110 +80,45 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
|
||||
# Resize the mask_image. Makes the filter 64x faster and doesn't hurt quality in latent scale anyway
|
||||
mask_image = mask_image.resize(
|
||||
(
|
||||
mask_image.width // LATENT_SCALE_FACTOR,
|
||||
mask_image.height // LATENT_SCALE_FACTOR,
|
||||
),
|
||||
resample=Image.Resampling.BILINEAR,
|
||||
)
|
||||
|
||||
mask_np_orig = np.array(mask_image, dtype=np.float32)
|
||||
|
||||
self.edge_radius = self.edge_radius // LATENT_SCALE_FACTOR # scale the edge radius to match the mask size
|
||||
|
||||
if self.edge_radius > 0:
|
||||
mask_np = 255 - mask_np_orig # invert so 0 is unmasked (higher values = higher denoise strength)
|
||||
dilated_mask = mask_np.copy()
|
||||
|
||||
# Create kernel based on coherence mode
|
||||
if self.coherence_mode == "Box Blur":
|
||||
# Create a circular distance kernel that fades from center outward
|
||||
kernel_size = self.edge_radius * 2 + 1
|
||||
center = self.edge_radius
|
||||
kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
|
||||
for i in range(kernel_size):
|
||||
for j in range(kernel_size):
|
||||
dist = np.sqrt((i - center) ** 2 + (j - center) ** 2)
|
||||
if dist <= self.edge_radius:
|
||||
kernel[i, j] = 1.0 - (dist / self.edge_radius)
|
||||
else: # Gaussian Blur or Staged
|
||||
# Create a Gaussian kernel
|
||||
kernel_size = self.edge_radius * 2 + 1
|
||||
kernel = cv2.getGaussianKernel(
|
||||
kernel_size, self.edge_radius / 2.5
|
||||
) # 2.5 is a magic number (standard deviation capturing)
|
||||
kernel = kernel * kernel.T # Make 2D gaussian kernel
|
||||
kernel = kernel / np.max(kernel) # Normalize center to 1.0
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
|
||||
# Ensure values outside radius are 0
|
||||
center = self.edge_radius
|
||||
for i in range(kernel_size):
|
||||
for j in range(kernel_size):
|
||||
dist = np.sqrt((i - center) ** 2 + (j - center) ** 2)
|
||||
if dist > self.edge_radius:
|
||||
kernel[i, j] = 0
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
|
||||
# 2D max filter
|
||||
mask_tensor = torch.tensor(mask_np)
|
||||
kernel_tensor = torch.tensor(kernel)
|
||||
dilated_mask = 255 - self.max_filter2D_torch(mask_tensor, kernel_tensor).cpu()
|
||||
dilated_mask = dilated_mask.numpy()
|
||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
blur_tensor[blur_tensor < 0] = 0.0
|
||||
|
||||
threshold = (1 - self.minimum_denoise) * 255
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever expanded mask is darker than the original mask but original was above threshhold, set it to the threshold
|
||||
# makes any expansion areas drop to threshhold. Raising minimum across the image happen outside of this if
|
||||
threshold_mask = (dilated_mask < mask_np_orig) & (mask_np_orig > threshold)
|
||||
dilated_mask = np.where(threshold_mask, threshold, mask_np_orig)
|
||||
|
||||
# wherever expanded mask is less than 255 but greater than threshold, drop it to threshold (minimum denoise)
|
||||
threshold_mask = (dilated_mask > threshold) & (dilated_mask < 255)
|
||||
dilated_mask = np.where(threshold_mask, threshold, dilated_mask)
|
||||
# wherever the blur_tensor is less than fully masked, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
|
||||
else:
|
||||
dilated_mask = mask_np_orig.copy()
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
|
||||
# convert to tensor
|
||||
dilated_mask = np.clip(dilated_mask, 0, 255).astype(np.uint8)
|
||||
mask_tensor = torch.tensor(dilated_mask, device=torch.device("cpu"))
|
||||
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
||||
|
||||
# binary mask for compositing
|
||||
expanded_mask = np.where((dilated_mask < 255), 0, 255)
|
||||
expanded_mask_image = Image.fromarray(expanded_mask.astype(np.uint8), mode="L")
|
||||
expanded_mask_image = expanded_mask_image.resize(
|
||||
(
|
||||
mask_image.width * LATENT_SCALE_FACTOR,
|
||||
mask_image.height * LATENT_SCALE_FACTOR,
|
||||
),
|
||||
resample=Image.Resampling.NEAREST,
|
||||
)
|
||||
# compute a [0, 1] mask from the blur_tensor
|
||||
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
|
||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||
|
||||
# restore the original mask size
|
||||
dilated_mask = Image.fromarray(dilated_mask.astype(np.uint8))
|
||||
dilated_mask = dilated_mask.resize(
|
||||
(
|
||||
mask_image.width * LATENT_SCALE_FACTOR,
|
||||
mask_image.height * LATENT_SCALE_FACTOR,
|
||||
),
|
||||
resample=Image.Resampling.NEAREST,
|
||||
)
|
||||
|
||||
# stack the mask as a tensor, repeating 4 times on dimmension 1
|
||||
dilated_mask_tensor = image_resized_to_grid_as_tensor(dilated_mask, normalize=False)
|
||||
mask_name = context.tensors.save(tensor=dilated_mask_tensor.unsqueeze(0))
|
||||
|
||||
masked_latents_name = None
|
||||
if self.unet is not None and self.vae is not None and self.image is not None:
|
||||
# all three fields must be present at the same time
|
||||
main_model_config = context.models.get_config(self.unet.unet.key)
|
||||
assert isinstance(main_model_config, MainConfigBase)
|
||||
if main_model_config.variant is ModelVariantType.Inpaint:
|
||||
mask = dilated_mask_tensor
|
||||
mask = blur_tensor
|
||||
vae_info: LoadedModel = context.models.load(self.vae.vae)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
@@ -202,29 +136,3 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
|
||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||
)
|
||||
|
||||
def max_filter2D_torch(self, image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This morphological operation is much faster in torch than numpy or opencv
|
||||
For reasonable kernel sizes, the overhead of copying the data to the GPU is not worth it.
|
||||
"""
|
||||
h, w = kernel.shape
|
||||
pad_h, pad_w = h // 2, w // 2
|
||||
|
||||
padded = torch.nn.functional.pad(image, (pad_w, pad_w, pad_h, pad_h), mode="constant", value=0)
|
||||
result = torch.zeros_like(image)
|
||||
|
||||
# This looks like it's inside out, but it does the same thing and is more efficient
|
||||
for i in range(h):
|
||||
for j in range(w):
|
||||
weight = kernel[i, j]
|
||||
if weight <= 0:
|
||||
continue
|
||||
|
||||
# Extract the region from padded tensor
|
||||
region = padded[i : i + image.shape[0], j : j + image.shape[1]]
|
||||
|
||||
# Apply weight and update max
|
||||
result = torch.maximum(result, region * weight)
|
||||
|
||||
return result
|
||||
|
||||
58
invokeai/app/invocations/custom_nodes/init.py
Normal file
58
invokeai/app/invocations/custom_nodes/init.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Invoke-managed custom node loader. See README.md for more information.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
loaded_count = 0
|
||||
|
||||
|
||||
for d in Path(__file__).parent.iterdir():
|
||||
# skip files
|
||||
if not d.is_dir():
|
||||
continue
|
||||
|
||||
# skip hidden directories
|
||||
if d.name.startswith("_") or d.name.startswith("."):
|
||||
continue
|
||||
|
||||
# skip directories without an `__init__.py`
|
||||
init = d / "__init__.py"
|
||||
if not init.exists():
|
||||
continue
|
||||
|
||||
module_name = init.parent.stem
|
||||
|
||||
# skip if already imported
|
||||
if module_name in globals():
|
||||
continue
|
||||
|
||||
# load the module, appending adding a suffix to identify it as a custom node pack
|
||||
spec = spec_from_file_location(module_name, init.absolute())
|
||||
|
||||
if spec is None or spec.loader is None:
|
||||
logger.warn(f"Could not load {init}")
|
||||
continue
|
||||
|
||||
logger.info(f"Loading node pack {module_name}")
|
||||
|
||||
try:
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
loaded_count += 1
|
||||
except Exception:
|
||||
full_error = traceback.format_exc()
|
||||
logger.error(f"Failed to load node pack {module_name}:\n{full_error}")
|
||||
|
||||
del init, module_name
|
||||
|
||||
if loaded_count > 0:
|
||||
logger.info(f"Loaded {loaded_count} node packs from {Path(__file__).parent}")
|
||||
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.controlnet import ControlField
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
@@ -39,8 +39,7 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -86,7 +85,6 @@ def get_scheduler(
|
||||
scheduler_info: ModelIdentifierField,
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
unet_config: AnyModelConfig,
|
||||
) -> Scheduler:
|
||||
"""Load a scheduler and apply some scheduler-specific overrides."""
|
||||
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
|
||||
@@ -105,9 +103,6 @@ def get_scheduler(
|
||||
"_backup": scheduler_config,
|
||||
}
|
||||
|
||||
if hasattr(unet_config, "prediction_type"):
|
||||
scheduler_config["prediction_type"] = unet_config.prediction_type
|
||||
|
||||
# make dpmpp_sde reproducable(seed can be passed only in initializer)
|
||||
if scheduler_class is DPMSolverSDEScheduler:
|
||||
scheduler_config["noise_sampler_seed"] = seed
|
||||
@@ -127,10 +122,10 @@ def get_scheduler(
|
||||
|
||||
@invocation(
|
||||
"denoise_latents",
|
||||
title="Denoise - SD1.5, SDXL",
|
||||
title="Denoise Latents",
|
||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
version="1.5.4",
|
||||
version="1.5.3",
|
||||
)
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
@@ -608,7 +603,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
end_step_percent=single_ip_adapter.end_step_percent,
|
||||
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
|
||||
mask=mask,
|
||||
method=single_ip_adapter.method,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -835,9 +829,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
@@ -857,7 +848,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
unet_config=unet_config,
|
||||
)
|
||||
|
||||
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
@@ -869,6 +859,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
### preview
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, unet_config.base)
|
||||
@@ -899,7 +892,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
### inpaint
|
||||
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
# NOTE: We used to identify inpainting models by inspecting the shape of the loaded UNet model weights. Now we
|
||||
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
|
||||
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
|
||||
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
|
||||
# prevalent, we will have to revisit how we initialize the inpainting extensions.
|
||||
@@ -1037,7 +1030,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
unet_config=unet_config,
|
||||
)
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
|
||||
@@ -4,7 +4,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector2
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -25,20 +25,20 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector.get_model_url_det())
|
||||
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector.get_model_url_pose())
|
||||
onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_det())
|
||||
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())
|
||||
|
||||
loaded_session_det = context.models.load_local_model(
|
||||
onnx_det_path, DWOpenposeDetector.create_onnx_inference_session
|
||||
onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
|
||||
)
|
||||
loaded_session_pose = context.models.load_local_model(
|
||||
onnx_pose_path, DWOpenposeDetector.create_onnx_inference_session
|
||||
onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
|
||||
)
|
||||
|
||||
with loaded_session_det as session_det, loaded_session_pose as session_pose:
|
||||
assert isinstance(session_det, ort.InferenceSession)
|
||||
assert isinstance(session_pose, ort.InferenceSession)
|
||||
detector = DWOpenposeDetector(session_det=session_det, session_pose=session_pose)
|
||||
detector = DWOpenposeDetector2(session_det=session_det, session_pose=session_pose)
|
||||
detected_image = detector.run(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
|
||||
@@ -1,19 +1,11 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
|
||||
from pydantic.fields import _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.backend.image_util.segment_anything.shared import BoundingBox
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
@@ -46,6 +38,27 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
used, and the type will be ignored. They are included here for backwards compatibility.
|
||||
"""
|
||||
|
||||
# region Model Field Types
|
||||
MainModel = "MainModelField"
|
||||
FluxMainModel = "FluxMainModelField"
|
||||
SD3MainModel = "SD3MainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VAEModel = "VAEModelField"
|
||||
FluxVAEModel = "FluxVAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
T5EncoderModel = "T5EncoderModelField"
|
||||
CLIPEmbedModel = "CLIPEmbedModelField"
|
||||
CLIPLEmbedModel = "CLIPLEmbedModelField"
|
||||
CLIPGEmbedModel = "CLIPGEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
ControlLoRAModel = "ControlLoRAModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
Scheduler = "SchedulerField"
|
||||
Any = "AnyField"
|
||||
@@ -54,7 +67,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
# region Internal Field Types
|
||||
_Collection = "CollectionField"
|
||||
_CollectionItem = "CollectionItemField"
|
||||
_IsIntermediate = "IsIntermediate"
|
||||
# endregion
|
||||
|
||||
# region DEPRECATED
|
||||
@@ -92,44 +104,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
CollectionItem = "DEPRECATED_CollectionItem"
|
||||
Enum = "DEPRECATED_Enum"
|
||||
WorkflowField = "DEPRECATED_WorkflowField"
|
||||
IsIntermediate = "DEPRECATED_IsIntermediate"
|
||||
BoardField = "DEPRECATED_BoardField"
|
||||
MetadataItem = "DEPRECATED_MetadataItem"
|
||||
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
|
||||
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
|
||||
MetadataDict = "DEPRECATED_MetadataDict"
|
||||
|
||||
# Deprecated Model Field Types - use ui_model_[base|type|variant|format] instead
|
||||
MainModel = "DEPRECATED_MainModelField"
|
||||
CogView4MainModel = "DEPRECATED_CogView4MainModelField"
|
||||
FluxMainModel = "DEPRECATED_FluxMainModelField"
|
||||
SD3MainModel = "DEPRECATED_SD3MainModelField"
|
||||
SDXLMainModel = "DEPRECATED_SDXLMainModelField"
|
||||
SDXLRefinerModel = "DEPRECATED_SDXLRefinerModelField"
|
||||
ONNXModel = "DEPRECATED_ONNXModelField"
|
||||
VAEModel = "DEPRECATED_VAEModelField"
|
||||
FluxVAEModel = "DEPRECATED_FluxVAEModelField"
|
||||
LoRAModel = "DEPRECATED_LoRAModelField"
|
||||
ControlNetModel = "DEPRECATED_ControlNetModelField"
|
||||
IPAdapterModel = "DEPRECATED_IPAdapterModelField"
|
||||
T2IAdapterModel = "DEPRECATED_T2IAdapterModelField"
|
||||
T5EncoderModel = "DEPRECATED_T5EncoderModelField"
|
||||
CLIPEmbedModel = "DEPRECATED_CLIPEmbedModelField"
|
||||
CLIPLEmbedModel = "DEPRECATED_CLIPLEmbedModelField"
|
||||
CLIPGEmbedModel = "DEPRECATED_CLIPGEmbedModelField"
|
||||
SpandrelImageToImageModel = "DEPRECATED_SpandrelImageToImageModelField"
|
||||
ControlLoRAModel = "DEPRECATED_ControlLoRAModelField"
|
||||
SigLipModel = "DEPRECATED_SigLipModelField"
|
||||
FluxReduxModel = "DEPRECATED_FluxReduxModelField"
|
||||
LlavaOnevisionModel = "DEPRECATED_LLaVAModelField"
|
||||
Imagen3Model = "DEPRECATED_Imagen3ModelField"
|
||||
Imagen4Model = "DEPRECATED_Imagen4ModelField"
|
||||
ChatGPT4oModel = "DEPRECATED_ChatGPT4oModelField"
|
||||
Gemini2_5Model = "DEPRECATED_Gemini2_5ModelField"
|
||||
FluxKontextModel = "DEPRECATED_FluxKontextModelField"
|
||||
Veo3Model = "DEPRECATED_Veo3ModelField"
|
||||
RunwayModel = "DEPRECATED_RunwayModelField"
|
||||
# endregion
|
||||
|
||||
|
||||
class UIComponent(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
@@ -153,7 +134,6 @@ class FieldDescriptions:
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
t5_encoder = "T5 tokenizer and text encoder"
|
||||
glm_encoder = "GLM (THUDM) tokenizer and text encoder"
|
||||
clip_embed_model = "CLIP Embed loader"
|
||||
clip_g_model = "CLIP-G Embed loader"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
@@ -168,12 +148,10 @@ class FieldDescriptions:
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
flux_model = "Flux model (Transformer) to load"
|
||||
sd3_model = "SD3 model (MMDiTX) to load"
|
||||
cogview4_model = "CogView4 model (Transformer) to load"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
spandrel_image_to_image_model = "Image-to-Image model"
|
||||
vllm_model = "VLLM model"
|
||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||
raw_prompt = "Raw prompt text (no parsing)"
|
||||
@@ -223,10 +201,6 @@ class FieldDescriptions:
|
||||
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
|
||||
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
|
||||
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
|
||||
flux_redux_conditioning = "FLUX Redux conditioning tensor"
|
||||
vllm_model = "The VLLM model to use"
|
||||
flux_fill_conditioning = "FLUX Fill conditioning tensor"
|
||||
flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
@@ -235,12 +209,6 @@ class ImageField(BaseModel):
|
||||
image_name: str = Field(description="The name of the image")
|
||||
|
||||
|
||||
class VideoField(BaseModel):
|
||||
"""A video primitive field"""
|
||||
|
||||
video_id: str = Field(description="The id of the video")
|
||||
|
||||
|
||||
class BoardField(BaseModel):
|
||||
"""A board primitive field"""
|
||||
|
||||
@@ -291,42 +259,12 @@ class FluxConditioningField(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class FluxReduxConditioningField(BaseModel):
|
||||
"""A FLUX Redux conditioning tensor primitive value"""
|
||||
|
||||
conditioning: TensorField = Field(description="The Redux image conditioning tensor.")
|
||||
mask: Optional[TensorField] = Field(
|
||||
default=None,
|
||||
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||
"included regions should be set to True.",
|
||||
)
|
||||
|
||||
|
||||
class FluxFillConditioningField(BaseModel):
|
||||
"""A FLUX Fill conditioning field."""
|
||||
|
||||
image: ImageField = Field(description="The FLUX Fill reference image.")
|
||||
mask: TensorField = Field(description="The FLUX Fill inpaint mask.")
|
||||
|
||||
|
||||
class FluxKontextConditioningField(BaseModel):
|
||||
"""A conditioning field for FLUX Kontext (reference image)."""
|
||||
|
||||
image: ImageField = Field(description="The Kontext reference image.")
|
||||
|
||||
|
||||
class SD3ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class CogView4ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
@@ -338,9 +276,14 @@ class ConditioningField(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class BoundingBoxField(BoundingBox):
|
||||
class BoundingBoxField(BaseModel):
|
||||
"""A bounding box primitive value."""
|
||||
|
||||
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
|
||||
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
|
||||
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
|
||||
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
|
||||
|
||||
score: Optional[float] = Field(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
@@ -349,6 +292,14 @@ class BoundingBoxField(BoundingBox):
|
||||
"when the bounding box was produced by a detector and has an associated confidence score.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_coords(self):
|
||||
if self.x_min > self.x_max:
|
||||
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
|
||||
if self.y_min > self.y_max:
|
||||
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
|
||||
return self
|
||||
|
||||
|
||||
class MetadataField(RootModel[dict[str, Any]]):
|
||||
"""
|
||||
@@ -406,8 +357,8 @@ class InputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
|
||||
input: Input
|
||||
orig_required: bool
|
||||
field_kind: FieldKind
|
||||
orig_required: bool = True
|
||||
default: Optional[Any] = None
|
||||
orig_default: Optional[Any] = None
|
||||
ui_hidden: bool = False
|
||||
@@ -415,15 +366,10 @@ class InputFieldJSONSchemaExtra(BaseModel):
|
||||
ui_component: Optional[UIComponent] = None
|
||||
ui_order: Optional[int] = None
|
||||
ui_choice_labels: Optional[dict[str, str]] = None
|
||||
ui_model_base: Optional[list[BaseModelType]] = None
|
||||
ui_model_type: Optional[list[ModelType]] = None
|
||||
ui_model_variant: Optional[list[ClipVariantType | ModelVariantType]] = None
|
||||
ui_model_format: Optional[list[ModelFormat]] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
use_enum_values=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -447,7 +393,7 @@ class WithWorkflow:
|
||||
workflow = None
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
|
||||
)
|
||||
super().__init_subclass__()
|
||||
@@ -476,121 +422,16 @@ class OutputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
|
||||
field_kind: FieldKind
|
||||
ui_hidden: bool = False
|
||||
ui_order: Optional[int] = None
|
||||
ui_type: Optional[UIType] = None
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
use_enum_values=True,
|
||||
)
|
||||
|
||||
|
||||
def migrate_model_ui_type(ui_type: UIType | str, json_schema_extra: dict[str, Any]) -> bool:
|
||||
"""Migrate deprecated model-specifier ui_type values to new-style ui_model_[base|type|variant|format] in json_schema_extra."""
|
||||
if not isinstance(ui_type, UIType):
|
||||
ui_type = UIType(ui_type)
|
||||
|
||||
ui_model_type: list[ModelType] | None = None
|
||||
ui_model_base: list[BaseModelType] | None = None
|
||||
ui_model_format: list[ModelFormat] | None = None
|
||||
ui_model_variant: list[ClipVariantType | ModelVariantType] | None = None
|
||||
|
||||
match ui_type:
|
||||
case UIType.MainModel:
|
||||
ui_model_base = [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.CogView4MainModel:
|
||||
ui_model_base = [BaseModelType.CogView4]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.FluxMainModel:
|
||||
ui_model_base = [BaseModelType.Flux]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.SD3MainModel:
|
||||
ui_model_base = [BaseModelType.StableDiffusion3]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.SDXLMainModel:
|
||||
ui_model_base = [BaseModelType.StableDiffusionXL]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.SDXLRefinerModel:
|
||||
ui_model_base = [BaseModelType.StableDiffusionXLRefiner]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.VAEModel:
|
||||
ui_model_type = [ModelType.VAE]
|
||||
case UIType.FluxVAEModel:
|
||||
ui_model_base = [BaseModelType.Flux]
|
||||
ui_model_type = [ModelType.VAE]
|
||||
case UIType.LoRAModel:
|
||||
ui_model_type = [ModelType.LoRA]
|
||||
case UIType.ControlNetModel:
|
||||
ui_model_type = [ModelType.ControlNet]
|
||||
case UIType.IPAdapterModel:
|
||||
ui_model_type = [ModelType.IPAdapter]
|
||||
case UIType.T2IAdapterModel:
|
||||
ui_model_type = [ModelType.T2IAdapter]
|
||||
case UIType.T5EncoderModel:
|
||||
ui_model_type = [ModelType.T5Encoder]
|
||||
case UIType.CLIPEmbedModel:
|
||||
ui_model_type = [ModelType.CLIPEmbed]
|
||||
case UIType.CLIPLEmbedModel:
|
||||
ui_model_type = [ModelType.CLIPEmbed]
|
||||
ui_model_variant = [ClipVariantType.L]
|
||||
case UIType.CLIPGEmbedModel:
|
||||
ui_model_type = [ModelType.CLIPEmbed]
|
||||
ui_model_variant = [ClipVariantType.G]
|
||||
case UIType.SpandrelImageToImageModel:
|
||||
ui_model_type = [ModelType.SpandrelImageToImage]
|
||||
case UIType.ControlLoRAModel:
|
||||
ui_model_type = [ModelType.ControlLoRa]
|
||||
case UIType.SigLipModel:
|
||||
ui_model_type = [ModelType.SigLIP]
|
||||
case UIType.FluxReduxModel:
|
||||
ui_model_type = [ModelType.FluxRedux]
|
||||
case UIType.LlavaOnevisionModel:
|
||||
ui_model_type = [ModelType.LlavaOnevision]
|
||||
case UIType.Imagen3Model:
|
||||
ui_model_base = [BaseModelType.Imagen3]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.Imagen4Model:
|
||||
ui_model_base = [BaseModelType.Imagen4]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.ChatGPT4oModel:
|
||||
ui_model_base = [BaseModelType.ChatGPT4o]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.Gemini2_5Model:
|
||||
ui_model_base = [BaseModelType.Gemini2_5]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.FluxKontextModel:
|
||||
ui_model_base = [BaseModelType.FluxKontext]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.Veo3Model:
|
||||
ui_model_base = [BaseModelType.Veo3]
|
||||
ui_model_type = [ModelType.Video]
|
||||
case UIType.RunwayModel:
|
||||
ui_model_base = [BaseModelType.Runway]
|
||||
ui_model_type = [ModelType.Video]
|
||||
case _:
|
||||
pass
|
||||
|
||||
did_migrate = False
|
||||
|
||||
if ui_model_type is not None:
|
||||
json_schema_extra["ui_model_type"] = [m.value for m in ui_model_type]
|
||||
did_migrate = True
|
||||
if ui_model_base is not None:
|
||||
json_schema_extra["ui_model_base"] = [m.value for m in ui_model_base]
|
||||
did_migrate = True
|
||||
if ui_model_format is not None:
|
||||
json_schema_extra["ui_model_format"] = [m.value for m in ui_model_format]
|
||||
did_migrate = True
|
||||
if ui_model_variant is not None:
|
||||
json_schema_extra["ui_model_variant"] = [m.value for m in ui_model_variant]
|
||||
did_migrate = True
|
||||
|
||||
return did_migrate
|
||||
|
||||
|
||||
def InputField(
|
||||
# copied from pydantic's Field
|
||||
# TODO: Can we support default_factory?
|
||||
@@ -614,104 +455,51 @@ def InputField(
|
||||
input: Input = Input.Any,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
ui_hidden: Optional[bool] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||
ui_model_base: Optional[BaseModelType | list[BaseModelType]] = None,
|
||||
ui_model_type: Optional[ModelType | list[ModelType]] = None,
|
||||
ui_model_variant: Optional[ClipVariantType | ModelVariantType | list[ClipVariantType | ModelVariantType]] = None,
|
||||
ui_model_format: Optional[ModelFormat | list[ModelFormat]] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field)
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
If the field is a `ModelIdentifierField`, use the `ui_model_[base|type|variant|format]` args to filter the model list
|
||||
in the Workflow Editor. Otherwise, use `ui_type` to provide extra type hints for the UI.
|
||||
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||
`Input.Direct` means a value must be provided on instantiation. \
|
||||
`Input.Connection` means the value must be provided by a connection. \
|
||||
`Input.Any` means either will do.
|
||||
|
||||
Don't use both `ui_type` and `ui_model_[base|type|variant|format]` - if both are provided, a warning will be
|
||||
logged and `ui_type` will be ignored.
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
Args:
|
||||
input: The kind of input this field requires.
|
||||
- `Input.Direct` means a value must be provided on instantiation.
|
||||
- `Input.Connection` means the value must be provided by a connection.
|
||||
- `Input.Any` means either will do.
|
||||
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||
For this case, you could provide `UIComponent.Textarea`.
|
||||
|
||||
ui_type: Optionally provides an extra type hint for the UI. In some situations, the field's type is not enough
|
||||
to infer the correct UI type. For example, Scheduler fields are enums, but we want to render a special scheduler
|
||||
dropdown in the UI. Use `UIType.Scheduler` to indicate this.
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
ui_component: Optionally specifies a specific component to use in the UI. The UI will always render a suitable
|
||||
component, but sometimes you want something different than the default. For example, a `string` field will
|
||||
default to a single-line input, but you may want a multi-line textarea instead. In this case, you could use
|
||||
`UIComponent.Textarea`.
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
|
||||
|
||||
ui_hidden: Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
ui_order: Specifies the order in which this field should be rendered in the UI. If omitted, the field will be
|
||||
rendered after all fields with an explicit order, in the order they are defined in the Invocation class.
|
||||
|
||||
ui_model_base: Specifies the base model architectures to filter the model list by in the Workflow Editor. For
|
||||
example, `ui_model_base=BaseModelType.StableDiffusionXL` will show only SDXL architecture models. This arg is
|
||||
only valid if this Input field is annotated as a `ModelIdentifierField`.
|
||||
|
||||
ui_model_type: Specifies the model type(s) to filter the model list by in the Workflow Editor. For example,
|
||||
`ui_model_type=ModelType.VAE` will show only VAE models. This arg is only valid if this Input field is
|
||||
annotated as a `ModelIdentifierField`.
|
||||
|
||||
ui_model_variant: Specifies the model variant(s) to filter the model list by in the Workflow Editor. For example,
|
||||
`ui_model_variant=ModelVariantType.Inpainting` will show only inpainting models. This arg is only valid if this
|
||||
Input field is annotated as a `ModelIdentifierField`.
|
||||
|
||||
ui_model_format: Specifies the model format(s) to filter the model list by in the Workflow Editor. For example,
|
||||
`ui_model_format=ModelFormat.Diffusers` will show only models in the diffusers format. This arg is only valid
|
||||
if this Input field is annotated as a `ModelIdentifierField`.
|
||||
|
||||
ui_choice_labels: Specifies the labels to use for the choices in an enum field. If omitted, the enum values
|
||||
will be used. This arg is only valid if the field is annotated with as a `Literal`. For example,
|
||||
`Literal["choice1", "choice2", "choice3"]` with `ui_choice_labels={"choice1": "Choice 1", "choice2": "Choice 2",
|
||||
"choice3": "Choice 3"}` will render a dropdown with the labels "Choice 1", "Choice 2" and "Choice 3".
|
||||
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
|
||||
"""
|
||||
|
||||
json_schema_extra_ = InputFieldJSONSchemaExtra(
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
field_kind=FieldKind.Input,
|
||||
orig_required=True,
|
||||
)
|
||||
|
||||
if ui_component is not None:
|
||||
json_schema_extra_.ui_component = ui_component
|
||||
if ui_hidden is not None:
|
||||
json_schema_extra_.ui_hidden = ui_hidden
|
||||
if ui_order is not None:
|
||||
json_schema_extra_.ui_order = ui_order
|
||||
if ui_choice_labels is not None:
|
||||
json_schema_extra_.ui_choice_labels = ui_choice_labels
|
||||
if ui_model_base is not None:
|
||||
if isinstance(ui_model_base, list):
|
||||
json_schema_extra_.ui_model_base = ui_model_base
|
||||
else:
|
||||
json_schema_extra_.ui_model_base = [ui_model_base]
|
||||
if ui_model_type is not None:
|
||||
if isinstance(ui_model_type, list):
|
||||
json_schema_extra_.ui_model_type = ui_model_type
|
||||
else:
|
||||
json_schema_extra_.ui_model_type = [ui_model_type]
|
||||
if ui_model_variant is not None:
|
||||
if isinstance(ui_model_variant, list):
|
||||
json_schema_extra_.ui_model_variant = ui_model_variant
|
||||
else:
|
||||
json_schema_extra_.ui_model_variant = [ui_model_variant]
|
||||
if ui_model_format is not None:
|
||||
if isinstance(ui_model_format, list):
|
||||
json_schema_extra_.ui_model_format = ui_model_format
|
||||
else:
|
||||
json_schema_extra_.ui_model_format = [ui_model_format]
|
||||
if ui_type is not None:
|
||||
json_schema_extra_.ui_type = ui_type
|
||||
|
||||
"""
|
||||
There is a conflict between the typing of invocation definitions and the typing of an invocation's
|
||||
`invoke()` function.
|
||||
@@ -741,7 +529,7 @@ def InputField(
|
||||
|
||||
if default_factory is not _Unset and default_factory is not None:
|
||||
default = default_factory()
|
||||
logger.warning('"default_factory" is not supported, calling it now to set "default"')
|
||||
logger.warn('"default_factory" is not supported, calling it now to set "default"')
|
||||
|
||||
# These are the args we may wish pass to the pydantic `Field()` function
|
||||
field_args = {
|
||||
@@ -783,7 +571,7 @@ def InputField(
|
||||
|
||||
return Field(
|
||||
**provided_args,
|
||||
json_schema_extra=json_schema_extra_.model_dump(exclude_unset=True),
|
||||
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
@@ -812,20 +600,20 @@ def OutputField(
|
||||
"""
|
||||
Creates an output field for an invocation output.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization)
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
Args:
|
||||
ui_type: Optionally provides an extra type hint for the UI. In some situations, the field's type is not enough
|
||||
to infer the correct UI type. For example, Scheduler fields are enums, but we want to render a special scheduler
|
||||
dropdown in the UI. Use `UIType.Scheduler` to indicate this.
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
ui_hidden: Specifies whether or not this field should be hidden in the UI.
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||
|
||||
ui_order: Specifies the order in which this field should be rendered in the UI. If omitted, the field will be
|
||||
rendered after all fields with an explicit order, in the order they are defined in the Invocation class.
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
"""
|
||||
|
||||
return Field(
|
||||
default=default,
|
||||
title=title,
|
||||
@@ -843,9 +631,9 @@ def OutputField(
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
json_schema_extra=OutputFieldJSONSchemaExtra(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
ui_type=ui_type,
|
||||
field_kind=FieldKind.Output,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ControlLoRAField, ModelIdentifierField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation_output("flux_control_lora_loader_output")
|
||||
@@ -21,19 +21,17 @@ class FluxControlLoRALoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"flux_control_lora_loader",
|
||||
title="Control LoRA - FLUX",
|
||||
title="Flux Control LoRA",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.1.1",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxControlLoRALoaderInvocation(BaseInvocation):
|
||||
"""LoRA model and Image to use with FLUX transformer generation."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.control_lora_model,
|
||||
title="Control LoRA",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.ControlLoRa,
|
||||
description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
|
||||
)
|
||||
image: ImageField = InputField(description="The image to encode.")
|
||||
weight: float = InputField(description="The weight of the LoRA.", default=1.0)
|
||||
|
||||
@@ -3,15 +3,15 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
class FluxControlNetField(BaseModel):
|
||||
@@ -52,15 +52,14 @@ class FluxControlNetOutput(BaseInvocationOutput):
|
||||
tags=["controlnet", "flux"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxControlNetInvocation(BaseInvocation):
|
||||
"""Collect FLUX ControlNet info to pass to other nodes."""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model,
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.ControlNet,
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: float | list[float] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
|
||||
@@ -10,18 +10,17 @@ from PIL import Image
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
FluxFillConditioningField,
|
||||
FluxKontextConditioningField,
|
||||
FluxReduxConditioningField,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
|
||||
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
|
||||
@@ -32,8 +31,8 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
@@ -47,12 +46,11 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
|
||||
from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantType
|
||||
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@@ -63,9 +61,10 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="4.1.0",
|
||||
version="3.2.2",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation):
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run denoising process with a FLUX transformer model."""
|
||||
|
||||
# If latents is provided, this means we are doing image-to-image.
|
||||
@@ -104,16 +103,6 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
redux_conditioning: FluxReduxConditioningField | list[FluxReduxConditioningField] | None = InputField(
|
||||
default=None,
|
||||
description="FLUX Redux conditioning tensor.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
fill_conditioning: FluxFillConditioningField | None = InputField(
|
||||
default=None,
|
||||
description="FLUX Fill conditioning.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
cfg_scale_start_step: int = InputField(
|
||||
default=0,
|
||||
@@ -145,20 +134,11 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# This node accepts a images for features like FLUX Fill, ControlNet, and Kontext, but needs to operate on them in
|
||||
# latent space. We'll run the VAE to encode them in this node instead of requiring the user to run the VAE in
|
||||
# upstream nodes.
|
||||
|
||||
ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
|
||||
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
|
||||
)
|
||||
|
||||
kontext_conditioning: FluxKontextConditioningField | list[FluxKontextConditioningField] | None = InputField(
|
||||
default=None,
|
||||
description="FLUX Kontext conditioning (reference image).",
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
@@ -210,23 +190,11 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
dtype=inference_dtype,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
redux_conditionings: list[FluxReduxConditioning] = self._load_redux_conditioning(
|
||||
context=context,
|
||||
redux_cond_field=self.redux_conditioning,
|
||||
packed_height=packed_h,
|
||||
packed_width=packed_w,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
|
||||
text_conditioning=pos_text_conditionings,
|
||||
redux_conditioning=redux_conditionings,
|
||||
img_seq_len=packed_h * packed_w,
|
||||
pos_text_conditionings, img_seq_len=packed_h * packed_w
|
||||
)
|
||||
neg_regional_prompting_extension = (
|
||||
RegionalPromptingExtension.from_text_conditioning(
|
||||
text_conditioning=neg_text_conditionings, redux_conditioning=[], img_seq_len=packed_h * packed_w
|
||||
)
|
||||
RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
|
||||
if neg_text_conditionings
|
||||
else None
|
||||
)
|
||||
@@ -275,19 +243,8 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
if is_schnell and self.control_lora:
|
||||
raise ValueError("Control LoRAs cannot be used with FLUX Schnell")
|
||||
|
||||
# Prepare the extra image conditioning tensor (img_cond) for either FLUX structural control or FLUX Fill.
|
||||
img_cond: torch.Tensor | None = None
|
||||
is_flux_fill = transformer_config.variant == ModelVariantType.Inpaint # type: ignore
|
||||
if is_flux_fill:
|
||||
img_cond = self._prep_flux_fill_img_cond(
|
||||
context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype
|
||||
)
|
||||
else:
|
||||
if self.fill_conditioning is not None:
|
||||
raise ValueError("fill_conditioning was provided, but the model is not a FLUX Fill model.")
|
||||
|
||||
if self.control_lora is not None:
|
||||
img_cond = self._prep_structural_control_img_cond(context)
|
||||
# Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
|
||||
img_cond = self._prep_structural_control_img_cond(context)
|
||||
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
|
||||
@@ -296,6 +253,7 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
img_cond = pack(img_cond) if img_cond is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
@@ -304,10 +262,10 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
assert packed_h * packed_w == x.shape[1]
|
||||
|
||||
# Prepare inpaint extension.
|
||||
inpaint_extension: RectifiedFlowInpaintExtension | None = None
|
||||
inpaint_extension: InpaintExtension | None = None
|
||||
if inpaint_mask is not None:
|
||||
assert init_latents is not None
|
||||
inpaint_extension = RectifiedFlowInpaintExtension(
|
||||
inpaint_extension = InpaintExtension(
|
||||
init_latents=init_latents,
|
||||
inpaint_mask=inpaint_mask,
|
||||
noise=noise,
|
||||
@@ -328,21 +286,6 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
cfg_scale_end_step=self.cfg_scale_end_step,
|
||||
)
|
||||
|
||||
kontext_extension = None
|
||||
if self.kontext_conditioning:
|
||||
if not self.controlnet_vae:
|
||||
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
|
||||
|
||||
kontext_extension = KontextExtension(
|
||||
context=context,
|
||||
kontext_conditioning=self.kontext_conditioning
|
||||
if isinstance(self.kontext_conditioning, list)
|
||||
else [self.kontext_conditioning],
|
||||
vae_field=self.controlnet_vae,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
# Prepare ControlNet extensions.
|
||||
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
|
||||
@@ -400,14 +343,6 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
|
||||
# Prepare Kontext conditioning if provided
|
||||
img_cond_seq = None
|
||||
img_cond_seq_ids = None
|
||||
if kontext_extension is not None:
|
||||
# Ensure batch sizes match
|
||||
kontext_extension.ensure_batch_size(x.shape[0])
|
||||
img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=x,
|
||||
@@ -423,8 +358,6 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
|
||||
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
|
||||
img_cond=img_cond,
|
||||
img_cond_seq=img_cond_seq,
|
||||
img_cond_seq_ids=img_cond_seq_ids,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
@@ -467,42 +400,6 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
|
||||
return text_conditionings
|
||||
|
||||
def _load_redux_conditioning(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
redux_cond_field: FluxReduxConditioningField | list[FluxReduxConditioningField] | None,
|
||||
packed_height: int,
|
||||
packed_width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> list[FluxReduxConditioning]:
|
||||
# Normalize to a list of FluxReduxConditioningFields.
|
||||
if redux_cond_field is None:
|
||||
return []
|
||||
|
||||
redux_cond_list = (
|
||||
[redux_cond_field] if isinstance(redux_cond_field, FluxReduxConditioningField) else redux_cond_field
|
||||
)
|
||||
|
||||
redux_conditionings: list[FluxReduxConditioning] = []
|
||||
for redux_cond_field in redux_cond_list:
|
||||
# Load the Redux conditioning tensor.
|
||||
redux_cond_data = context.tensors.load(redux_cond_field.conditioning.tensor_name)
|
||||
redux_cond_data.to(device=device, dtype=dtype)
|
||||
|
||||
# Load the mask, if provided.
|
||||
mask: Optional[torch.Tensor] = None
|
||||
if redux_cond_field.mask is not None:
|
||||
mask = context.tensors.load(redux_cond_field.mask.tensor_name)
|
||||
mask = mask.to(device=device)
|
||||
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
|
||||
mask, packed_height, packed_width, dtype, device
|
||||
)
|
||||
|
||||
redux_conditionings.append(FluxReduxConditioning(redux_embeddings=redux_cond_data, mask=mask))
|
||||
|
||||
return redux_conditionings
|
||||
|
||||
@classmethod
|
||||
def prep_cfg_scale(
|
||||
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
|
||||
@@ -713,70 +610,7 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
img_cond = einops.rearrange(img_cond, "h w c -> 1 c h w")
|
||||
|
||||
vae_info = context.models.load(self.controlnet_vae.vae)
|
||||
img_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
|
||||
|
||||
return pack(img_cond)
|
||||
|
||||
def _prep_flux_fill_img_cond(
|
||||
self, context: InvocationContext, device: torch.device, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
"""Prepare the FLUX Fill conditioning. This method should be called iff the model is a FLUX Fill model.
|
||||
|
||||
This logic is based on:
|
||||
https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/sampling.py#L107-L157
|
||||
"""
|
||||
# Validate inputs.
|
||||
if self.fill_conditioning is None:
|
||||
raise ValueError("A FLUX Fill model is being used without fill_conditioning.")
|
||||
# TODO(ryand): We should probable rename controlnet_vae. It's used for more than just ControlNets.
|
||||
if self.controlnet_vae is None:
|
||||
raise ValueError("A FLUX Fill model is being used without controlnet_vae.")
|
||||
if self.control_lora is not None:
|
||||
raise ValueError(
|
||||
"A FLUX Fill model is being used, but a control_lora was provided. Control LoRAs are not compatible with FLUX Fill models."
|
||||
)
|
||||
|
||||
# Log input warnings related to FLUX Fill usage.
|
||||
if self.denoise_mask is not None:
|
||||
context.logger.warning(
|
||||
"Both fill_conditioning and a denoise_mask were provided. You probably meant to use one or the other."
|
||||
)
|
||||
if self.guidance < 25.0:
|
||||
context.logger.warning("A guidance value of ~30.0 is recommended for FLUX Fill models.")
|
||||
|
||||
# Load the conditioning image and resize it to the target image size.
|
||||
cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
|
||||
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
|
||||
cond_img = np.array(cond_img)
|
||||
cond_img = torch.from_numpy(cond_img).float() / 127.5 - 1.0
|
||||
cond_img = einops.rearrange(cond_img, "h w c -> 1 c h w")
|
||||
cond_img = cond_img.to(device=device, dtype=dtype)
|
||||
|
||||
# Load the mask and resize it to the target image size.
|
||||
mask = context.tensors.load(self.fill_conditioning.mask.tensor_name)
|
||||
# We expect mask to be a bool tensor with shape [1, H, W].
|
||||
assert mask.dtype == torch.bool
|
||||
assert mask.dim() == 3
|
||||
assert mask.shape[0] == 1
|
||||
mask = tv_resize(mask, size=[self.height, self.width], interpolation=tv_transforms.InterpolationMode.NEAREST)
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
mask = einops.rearrange(mask, "1 h w -> 1 1 h w")
|
||||
|
||||
# Prepare image conditioning.
|
||||
cond_img = cond_img * (1 - mask)
|
||||
vae_info = context.models.load(self.controlnet_vae.vae)
|
||||
cond_img = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=cond_img)
|
||||
cond_img = pack(cond_img)
|
||||
|
||||
# Prepare mask conditioning.
|
||||
mask = mask[:, 0, :, :]
|
||||
# Rearrange mask to a 16-channel representation that matches the shape of the VAE-encoded latent space.
|
||||
mask = einops.rearrange(mask, "b (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
|
||||
mask = pack(mask)
|
||||
|
||||
# Merge image and mask conditioning.
|
||||
img_cond = torch.cat((cond_img, mask), dim=-1)
|
||||
return img_cond
|
||||
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
|
||||
|
||||
def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
|
||||
if self.ip_adapter is None:
|
||||
@@ -899,10 +733,7 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
# The denoise function now handles Kontext conditioning correctly,
|
||||
# so we don't need to slice the latents here
|
||||
latents = state.latents.float()
|
||||
state.latents = unpack(latents, self.height, self.width).squeeze()
|
||||
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
|
||||
context.util.flux_step_callback(state)
|
||||
|
||||
return step_callback
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxFillConditioningField,
|
||||
InputField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("flux_fill_output")
|
||||
class FluxFillOutput(BaseInvocationOutput):
|
||||
"""The conditioning output of a FLUX Fill invocation."""
|
||||
|
||||
fill_cond: FluxFillConditioningField = OutputField(
|
||||
description=FieldDescriptions.flux_redux_conditioning, title="Conditioning"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_fill",
|
||||
title="FLUX Fill Conditioning",
|
||||
tags=["inpaint"],
|
||||
category="inpaint",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class FluxFillInvocation(BaseInvocation):
|
||||
"""Prepare the FLUX Fill conditioning data."""
|
||||
|
||||
image: ImageField = InputField(description="The FLUX Fill reference image.")
|
||||
mask: TensorField = InputField(
|
||||
description="The bool inpainting mask. Excluded regions should be set to "
|
||||
"False, included regions should be set to True.",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxFillOutput:
|
||||
return FluxFillOutput(fill_cond=FluxFillConditioningField(image=self.image, mask=self.mask))
|
||||
@@ -4,8 +4,8 @@ from typing import List, Literal, Union
|
||||
from pydantic import field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import InputField
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import InputField, UIType
|
||||
from invokeai.app.invocations.ip_adapter import (
|
||||
CLIP_VISION_MODEL_MAP,
|
||||
IPAdapterField,
|
||||
@@ -20,7 +20,6 @@ from invokeai.backend.model_manager.config import (
|
||||
IPAdapterCheckpointConfig,
|
||||
IPAdapterInvokeAIConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -29,6 +28,7 @@ from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxIPAdapterInvocation(BaseInvocation):
|
||||
"""Collects FLUX IP-Adapter info to pass to other nodes."""
|
||||
@@ -37,10 +37,7 @@ class FluxIPAdapterInvocation(BaseInvocation):
|
||||
|
||||
image: ImageField = InputField(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: ModelIdentifierField = InputField(
|
||||
description="The IP-Adapter model.",
|
||||
title="IP-Adapter Model",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.IPAdapter,
|
||||
description="The IP-Adapter model.", title="IP-Adapter Model", ui_type=UIType.IPAdapterModel
|
||||
)
|
||||
# Currently, the only known ViT model used by FLUX IP-Adapters is ViT-L.
|
||||
clip_vision_model: Literal["ViT-L"] = InputField(description="CLIP Vision model to use.", default="ViT-L")
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxKontextConditioningField,
|
||||
InputField,
|
||||
OutputField,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("flux_kontext_output")
|
||||
class FluxKontextOutput(BaseInvocationOutput):
|
||||
"""The conditioning output of a FLUX Kontext invocation."""
|
||||
|
||||
kontext_cond: FluxKontextConditioningField = OutputField(
|
||||
description=FieldDescriptions.flux_kontext_conditioning, title="Kontext Conditioning"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_kontext",
|
||||
title="Kontext Conditioning - FLUX",
|
||||
tags=["conditioning", "kontext", "flux"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxKontextInvocation(BaseInvocation):
|
||||
"""Prepares a reference image for FLUX Kontext conditioning."""
|
||||
|
||||
image: ImageField = InputField(description="The Kontext reference image.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxKontextOutput:
|
||||
"""Packages the provided image into a Kontext conditioning field."""
|
||||
return FluxKontextOutput(kontext_cond=FluxKontextConditioningField(image=self.image))
|
||||
@@ -3,13 +3,14 @@ from typing import Optional
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
|
||||
|
||||
@invocation_output("flux_lora_loader_output")
|
||||
@@ -20,26 +21,21 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
|
||||
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
|
||||
)
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
t5_encoder: Optional[T5EncoderField] = OutputField(
|
||||
default=None, description=FieldDescriptions.t5_encoder, title="T5 Encoder"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_lora_loader",
|
||||
title="Apply LoRA - FLUX",
|
||||
title="FLUX LoRA",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.2.1",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model,
|
||||
title="LoRA",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.LoRA,
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
transformer: TransformerField | None = InputField(
|
||||
@@ -54,12 +50,6 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
t5_encoder: T5EncoderField | None = InputField(
|
||||
default=None,
|
||||
title="T5 Encoder",
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
@@ -72,8 +62,6 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
|
||||
if self.clip and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
||||
raise ValueError(f'LoRA "{lora_key}" already applied to CLIP encoder.')
|
||||
if self.t5_encoder and any(lora.lora.key == lora_key for lora in self.t5_encoder.loras):
|
||||
raise ValueError(f'LoRA "{lora_key}" already applied to T5 encoder.')
|
||||
|
||||
output = FluxLoRALoaderOutput()
|
||||
|
||||
@@ -94,30 +82,23 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
if self.t5_encoder is not None:
|
||||
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
|
||||
output.t5_encoder.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_lora_collection_loader",
|
||||
title="Apply LoRA Collection - FLUX",
|
||||
title="FLUX LoRA Collection Loader",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.3.1",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of LoRAs to a FLUX transformer."""
|
||||
|
||||
loras: Optional[LoRAField | list[LoRAField]] = InputField(
|
||||
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
|
||||
loras: LoRAField | list[LoRAField] = InputField(
|
||||
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
|
||||
)
|
||||
|
||||
transformer: Optional[TransformerField] = InputField(
|
||||
@@ -132,30 +113,13 @@ class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
t5_encoder: T5EncoderField | None = InputField(
|
||||
default=None,
|
||||
title="T5 Encoder",
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
|
||||
output = FluxLoRALoaderOutput()
|
||||
loras = self.loras if isinstance(self.loras, list) else [self.loras]
|
||||
added_loras: list[str] = []
|
||||
|
||||
if self.transformer is not None:
|
||||
output.transformer = self.transformer.model_copy(deep=True)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
|
||||
if self.t5_encoder is not None:
|
||||
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
|
||||
|
||||
for lora in loras:
|
||||
if lora is None:
|
||||
continue
|
||||
if lora.lora.key in added_loras:
|
||||
continue
|
||||
|
||||
@@ -166,13 +130,14 @@ class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
|
||||
added_loras.append(lora.lora.key)
|
||||
|
||||
if self.transformer is not None and output.transformer is not None:
|
||||
if self.transformer is not None:
|
||||
if output.transformer is None:
|
||||
output.transformer = self.transformer.model_copy(deep=True)
|
||||
output.transformer.loras.append(lora)
|
||||
|
||||
if self.clip is not None and output.clip is not None:
|
||||
if self.clip is not None:
|
||||
if output.clip is None:
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip.loras.append(lora)
|
||||
|
||||
if self.t5_encoder is not None and output.t5_encoder is not None:
|
||||
output.t5_encoder.loras.append(lora)
|
||||
|
||||
return output
|
||||
|
||||
@@ -3,21 +3,18 @@ from typing import Literal
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.t5_model_identifier import (
|
||||
preprocess_t5_encoder_model_identifier,
|
||||
preprocess_t5_tokenizer_model_identifier,
|
||||
)
|
||||
from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
|
||||
|
||||
@invocation_output("flux_model_loader_output")
|
||||
@@ -36,40 +33,34 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"flux_model_loader",
|
||||
title="Main Model - FLUX",
|
||||
title="Flux Main Model",
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.6",
|
||||
version="1.0.4",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a flux base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
ui_type=UIType.FluxMainModel,
|
||||
input=Input.Direct,
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Direct,
|
||||
title="T5 Encoder",
|
||||
ui_model_type=ModelType.T5Encoder,
|
||||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
|
||||
)
|
||||
|
||||
clip_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP Embed",
|
||||
ui_model_type=ModelType.CLIPEmbed,
|
||||
)
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model,
|
||||
title="VAE",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.VAE,
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||
@@ -83,8 +74,8 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer2 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model)
|
||||
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)
|
||||
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
|
||||
transformer_config = context.models.get_config(transformer)
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
@@ -92,7 +83,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
)
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
import math
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxReduxConditioningField,
|
||||
InputField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.starter_models import siglip
|
||||
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation_output("flux_redux_output")
|
||||
class FluxReduxOutput(BaseInvocationOutput):
|
||||
"""The conditioning output of a FLUX Redux invocation."""
|
||||
|
||||
redux_cond: FluxReduxConditioningField = OutputField(
|
||||
description=FieldDescriptions.flux_redux_conditioning, title="Conditioning"
|
||||
)
|
||||
|
||||
|
||||
DOWNSAMPLING_FUNCTIONS = Literal["nearest", "bilinear", "bicubic", "area", "nearest-exact"]
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_redux",
|
||||
title="FLUX Redux",
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
version="2.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class FluxReduxInvocation(BaseInvocation):
|
||||
"""Runs a FLUX Redux model to generate a conditioning tensor."""
|
||||
|
||||
image: ImageField = InputField(description="The FLUX Redux image prompt.")
|
||||
mask: Optional[TensorField] = InputField(
|
||||
default=None,
|
||||
description="The bool mask associated with this FLUX Redux image prompt. Excluded regions should be set to "
|
||||
"False, included regions should be set to True.",
|
||||
)
|
||||
redux_model: ModelIdentifierField = InputField(
|
||||
description="The FLUX Redux model to use.",
|
||||
title="FLUX Redux Model",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.FluxRedux,
|
||||
)
|
||||
downsampling_factor: int = InputField(
|
||||
ge=1,
|
||||
le=9,
|
||||
default=1,
|
||||
description="Redux Downsampling Factor (1-9)",
|
||||
)
|
||||
downsampling_function: DOWNSAMPLING_FUNCTIONS = InputField(
|
||||
default="area",
|
||||
description="Redux Downsampling Function",
|
||||
)
|
||||
weight: float = InputField(
|
||||
ge=0,
|
||||
le=1,
|
||||
default=1.0,
|
||||
description="Redux weight (0.0-1.0)",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxReduxOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
encoded_x = self._siglip_encode(context, image)
|
||||
redux_conditioning = self._flux_redux_encode(context, encoded_x)
|
||||
if self.downsampling_factor > 1 or self.weight != 1.0:
|
||||
redux_conditioning = self._downsample_weight(context, redux_conditioning)
|
||||
|
||||
tensor_name = context.tensors.save(redux_conditioning)
|
||||
return FluxReduxOutput(
|
||||
redux_cond=FluxReduxConditioningField(conditioning=TensorField(tensor_name=tensor_name), mask=self.mask)
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _downsample_weight(self, context: InvocationContext, redux_conditioning: torch.Tensor) -> torch.Tensor:
|
||||
# Downsampling derived from https://github.com/kaibioinfo/ComfyUI_AdvancedRefluxControl
|
||||
(b, t, h) = redux_conditioning.shape
|
||||
m = int(math.sqrt(t))
|
||||
if self.downsampling_factor > 1:
|
||||
redux_conditioning = redux_conditioning.view(b, m, m, h)
|
||||
redux_conditioning = torch.nn.functional.interpolate(
|
||||
redux_conditioning.transpose(1, -1),
|
||||
size=(m // self.downsampling_factor, m // self.downsampling_factor),
|
||||
mode=self.downsampling_function,
|
||||
)
|
||||
redux_conditioning = redux_conditioning.transpose(1, -1).reshape(b, -1, h)
|
||||
if self.weight != 1.0:
|
||||
redux_conditioning = redux_conditioning * self.weight * self.weight
|
||||
return redux_conditioning
|
||||
|
||||
@torch.no_grad()
|
||||
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
|
||||
siglip_model_config = self._get_siglip_model(context)
|
||||
with context.models.load(siglip_model_config.key).model_on_device() as (_, model):
|
||||
assert isinstance(model, SiglipVisionModel)
|
||||
|
||||
model_abs_path = context.models.get_absolute_path(siglip_model_config)
|
||||
processor = SiglipImageProcessor.from_pretrained(model_abs_path, local_files_only=True)
|
||||
assert isinstance(processor, SiglipImageProcessor)
|
||||
|
||||
siglip_pipeline = SigLipPipeline(processor, model)
|
||||
return siglip_pipeline.encode_image(
|
||||
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _flux_redux_encode(self, context: InvocationContext, encoded_x: torch.Tensor) -> torch.Tensor:
|
||||
with context.models.load(self.redux_model).model_on_device() as (_, flux_redux):
|
||||
assert isinstance(flux_redux, FluxReduxModel)
|
||||
dtype = next(flux_redux.parameters()).dtype
|
||||
encoded_x = encoded_x.to(dtype=dtype)
|
||||
return flux_redux(encoded_x)
|
||||
|
||||
def _get_siglip_model(self, context: InvocationContext) -> AnyModelConfig:
|
||||
siglip_models = context.models.search_by_attrs(name=siglip.name, base=BaseModelType.Any, type=ModelType.SigLIP)
|
||||
|
||||
if not len(siglip_models) > 0:
|
||||
context.logger.warning(
|
||||
f"The SigLIP model required by FLUX Redux ({siglip.name}) is not installed. Downloading and installing now. This may take a while."
|
||||
)
|
||||
|
||||
# TODO(psyche): Can the probe reliably determine the type of the model? Just hardcoding it bc I don't want to experiment now
|
||||
config_overrides = ModelRecordChanges(name=siglip.name, type=ModelType.SigLIP)
|
||||
|
||||
# Queue the job
|
||||
job = context._services.model_manager.install.heuristic_import(siglip.source, config=config_overrides)
|
||||
|
||||
# Wait for up to 10 minutes - model is ~3.5GB
|
||||
context._services.model_manager.install.wait_for_job(job, timeout=600)
|
||||
|
||||
siglip_models = context.models.search_by_attrs(
|
||||
name=siglip.name,
|
||||
base=BaseModelType.Any,
|
||||
type=ModelType.SigLIP,
|
||||
)
|
||||
|
||||
if len(siglip_models) == 0:
|
||||
context.logger.error("Error while fetching SigLIP for FLUX Redux")
|
||||
assert len(siglip_models) == 1
|
||||
|
||||
return siglip_models[0]
|
||||
@@ -1,10 +1,10 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Literal, Optional, Tuple, Union
|
||||
from typing import Iterator, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
@@ -17,19 +17,20 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
from invokeai.backend.model_manager import ModelFormat
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_encoder",
|
||||
title="Prompt - FLUX",
|
||||
title="FLUX Text Encoding",
|
||||
tags=["prompt", "conditioning", "flux"],
|
||||
category="conditioning",
|
||||
version="1.1.2",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
"""Encodes and preps a prompt for a flux image."""
|
||||
@@ -70,50 +71,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
prompt = [self.prompt]
|
||||
|
||||
t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
t5_encoder_config = t5_encoder_info.config
|
||||
assert t5_encoder_config is not None
|
||||
|
||||
with (
|
||||
t5_encoder_info.model_on_device() as (cached_weights, t5_text_encoder),
|
||||
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
|
||||
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
|
||||
ExitStack() as exit_stack,
|
||||
):
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
|
||||
|
||||
# Determine if the model is quantized.
|
||||
# If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
|
||||
# slower inference than direct patching, but is agnostic to the quantization format.
|
||||
if t5_encoder_config.format in [ModelFormat.T5Encoder, ModelFormat.Diffusers]:
|
||||
model_is_quantized = False
|
||||
elif t5_encoder_config.format in [
|
||||
ModelFormat.BnbQuantizedLlmInt8b,
|
||||
ModelFormat.BnbQuantizednf4b,
|
||||
ModelFormat.GGUFQuantized,
|
||||
]:
|
||||
model_is_quantized = True
|
||||
else:
|
||||
raise ValueError(f"Unsupported model format: {t5_encoder_config.format}")
|
||||
|
||||
# Apply LoRA models to the T5 encoder.
|
||||
# Note: We apply the LoRA after the encoder has been moved to its target device for faster patching.
|
||||
exit_stack.enter_context(
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=t5_text_encoder,
|
||||
patches=self._t5_lora_iterator(context),
|
||||
prefix=FLUX_LORA_T5_PREFIX,
|
||||
dtype=t5_text_encoder.dtype,
|
||||
cached_weights=cached_weights,
|
||||
force_sidecar_patching=model_is_quantized,
|
||||
)
|
||||
)
|
||||
assert isinstance(t5_tokenizer, T5Tokenizer)
|
||||
|
||||
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
self._log_t5_tokenization(context, t5_tokenizer)
|
||||
|
||||
context.util.signal_progress("Running T5 encoder")
|
||||
prompt_embeds = t5_encoder(prompt)
|
||||
|
||||
@@ -154,9 +120,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
self._log_clip_tokenization(context, clip_tokenizer)
|
||||
|
||||
context.util.signal_progress("Running CLIP encoder")
|
||||
pooled_prompt_embeds = clip_encoder(prompt)
|
||||
|
||||
@@ -169,95 +132,3 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
assert isinstance(lora_info.model, ModelPatchRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
def _t5_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
|
||||
for lora in self.t5_encoder.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, ModelPatchRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
def _log_t5_tokenization(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
|
||||
) -> None:
|
||||
"""Logs the tokenization of a prompt for a T5-based model like FLUX."""
|
||||
|
||||
# Tokenize the prompt using the same parameters as the model's text encoder.
|
||||
# T5 tokenizers add an EOS token (</s>) and then pad to max_length.
|
||||
tokenized_output = tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
max_length=self.t5_max_seq_len,
|
||||
truncation=True,
|
||||
add_special_tokens=True, # This is important for T5 to add the EOS token.
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = tokenized_output.input_ids[0]
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
# The T5 tokenizer uses a space-like character ' ' (U+2581) to denote spaces.
|
||||
# We'll replace it with a regular space for readability.
|
||||
tokens = [t.replace("\u2581", " ") for t in tokens]
|
||||
|
||||
tokenized_str = ""
|
||||
used_tokens = 0
|
||||
for token in tokens:
|
||||
if token == tokenizer.eos_token:
|
||||
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
|
||||
used_tokens += 1
|
||||
elif token == tokenizer.pad_token:
|
||||
# tokenized_str += f"\x1b[0;34m{token}\x1b[0m" # Blue for PAD
|
||||
continue
|
||||
else:
|
||||
color = (used_tokens % 6) + 1 # Cycle through 6 colors
|
||||
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
|
||||
used_tokens += 1
|
||||
|
||||
context.logger.info(f">> [T5 TOKENLOG] Tokens ({used_tokens}/{self.t5_max_seq_len}):")
|
||||
context.logger.info(f"{tokenized_str}\x1b[0m")
|
||||
|
||||
def _log_clip_tokenization(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
tokenizer: CLIPTokenizer,
|
||||
) -> None:
|
||||
"""Logs the tokenization of a prompt for a CLIP-based model."""
|
||||
max_length = tokenizer.model_max_length
|
||||
|
||||
tokenized_output = tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = tokenized_output.input_ids[0]
|
||||
attention_mask = tokenized_output.attention_mask[0]
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
# The CLIP tokenizer uses '</w>' to denote spaces.
|
||||
# We'll replace it with a regular space for readability.
|
||||
tokens = [t.replace("</w>", " ") for t in tokens]
|
||||
|
||||
tokenized_str = ""
|
||||
used_tokens = 0
|
||||
for i, token in enumerate(tokens):
|
||||
if attention_mask[i] == 0:
|
||||
# Do not log padding tokens.
|
||||
continue
|
||||
|
||||
if token == tokenizer.bos_token:
|
||||
tokenized_str += f"\x1b[0;32m{token}\x1b[0m" # Green for BOS
|
||||
elif token == tokenizer.eos_token:
|
||||
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
|
||||
else:
|
||||
color = (used_tokens % 6) + 1 # Cycle through 6 colors
|
||||
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
|
||||
used_tokens += 1
|
||||
|
||||
context.logger.info(f">> [CLIP TOKENLOG] Tokens ({used_tokens}/{max_length}):")
|
||||
context.logger.info(f"{tokenized_str}\x1b[0m")
|
||||
|
||||
@@ -3,6 +3,7 @@ from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@@ -17,15 +18,14 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_vae_decode",
|
||||
title="Latents to Image - FLUX",
|
||||
title="FLUX Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i", "flux"],
|
||||
category="latents",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
@@ -39,11 +39,22 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision).
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 1090 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
# We add a 20% buffer to the working memory estimate to be safe.
|
||||
working_memory = working_memory * 1.2
|
||||
return int(working_memory)
|
||||
|
||||
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
|
||||
assert isinstance(vae_info.model, AutoEncoder)
|
||||
estimated_working_memory = estimate_vae_working_memory_flux(
|
||||
operation="decode", image_tensor=latents, vae=vae_info.model
|
||||
)
|
||||
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
|
||||
@@ -15,15 +15,14 @@ from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_vae_encode",
|
||||
title="Image to Latents - FLUX",
|
||||
title="FLUX Image to Latents",
|
||||
tags=["latents", "image", "vae", "i2l", "flux"],
|
||||
category="latents",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
@@ -42,12 +41,8 @@ class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
|
||||
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
|
||||
# should be used for VAE encode sampling.
|
||||
assert isinstance(vae_info.model, AutoEncoder)
|
||||
estimated_working_memory = estimate_vae_working_memory_flux(
|
||||
operation="encode", image_tensor=image_tensor, vae=vae_info.model
|
||||
)
|
||||
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
|
||||
|
||||
@@ -6,7 +6,7 @@ from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
|
||||
from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
|
||||
|
||||
@invocation_output("ideal_size_output")
|
||||
@@ -19,16 +19,16 @@ class IdealSizeOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"ideal_size",
|
||||
title="Ideal Size - SD1.5, SDXL",
|
||||
title="Ideal Size",
|
||||
tags=["latents", "math", "ideal_size"],
|
||||
version="1.0.6",
|
||||
version="1.0.3",
|
||||
)
|
||||
class IdealSizeInvocation(BaseInvocation):
|
||||
"""Calculates the ideal size for generation to avoid duplication"""
|
||||
|
||||
width: int = InputField(default=1024, description="Final image width")
|
||||
height: int = InputField(default=576, description="Final image height")
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet)
|
||||
unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
|
||||
multiplier: float = InputField(
|
||||
default=1.0,
|
||||
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "
|
||||
@@ -41,16 +41,11 @@ class IdealSizeInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
aspect = self.width / self.height
|
||||
|
||||
if unet_config.base == BaseModelType.StableDiffusion1:
|
||||
dimension = 512
|
||||
elif unet_config.base == BaseModelType.StableDiffusion2:
|
||||
dimension: float = 512
|
||||
if unet_config.base == BaseModelType.StableDiffusion2:
|
||||
dimension = 768
|
||||
elif unet_config.base in (BaseModelType.StableDiffusionXL, BaseModelType.Flux, BaseModelType.StableDiffusion3):
|
||||
elif unet_config.base == BaseModelType.StableDiffusionXL:
|
||||
dimension = 1024
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {unet_config.base}")
|
||||
|
||||
dimension = dimension * self.multiplier
|
||||
min_dimension = math.floor(dimension * 0.5)
|
||||
model_area = dimension * dimension # hardcoded for now since all models are trained on square images
|
||||
|
||||
@@ -13,7 +13,6 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import (
|
||||
BoundingBoxField,
|
||||
ColorField,
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
@@ -24,7 +23,6 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
|
||||
@@ -163,12 +161,12 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
crop: bool = InputField(default=False, description="Crop to base image dimensions")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.images.get_pil(self.base_image.image_name, mode="RGBA")
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
base_image = context.images.get_pil(self.base_image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
mask = None
|
||||
if self.mask is not None:
|
||||
mask = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
mask = ImageOps.invert(mask)
|
||||
mask = context.images.get_pil(self.mask.image_name)
|
||||
mask = ImageOps.invert(mask.convert("L"))
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
|
||||
min_x = min(0, self.x)
|
||||
@@ -178,11 +176,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
|
||||
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
||||
|
||||
# Create a temporary image to paste the image with transparency
|
||||
temp_image = Image.new("RGBA", new_image.size)
|
||||
temp_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||
new_image = Image.alpha_composite(new_image, temp_image)
|
||||
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||
|
||||
if self.crop:
|
||||
base_w, base_h = base_image.size
|
||||
@@ -307,44 +301,14 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Split the image into RGBA channels
|
||||
r, g, b, a = image.split()
|
||||
|
||||
# Premultiply RGB channels by alpha
|
||||
premultiplied_image = ImageChops.multiply(image, a.convert("RGBA"))
|
||||
premultiplied_image.putalpha(a)
|
||||
|
||||
# Apply the blur
|
||||
blur = (
|
||||
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
|
||||
)
|
||||
blurred_image = premultiplied_image.filter(blur)
|
||||
blur_image = image.filter(blur)
|
||||
|
||||
# Split the blurred image into RGBA channels
|
||||
r, g, b, a_orig = blurred_image.split()
|
||||
|
||||
# Convert to float using NumPy. float 32/64 division are much faster than float 16
|
||||
r = numpy.array(r, dtype=numpy.float32)
|
||||
g = numpy.array(g, dtype=numpy.float32)
|
||||
b = numpy.array(b, dtype=numpy.float32)
|
||||
a = numpy.array(a_orig, dtype=numpy.float32) / 255.0 # Normalize alpha to [0, 1]
|
||||
|
||||
# Unpremultiply RGB channels by alpha
|
||||
r /= a + 1e-6 # Add a small epsilon to avoid division by zero
|
||||
g /= a + 1e-6
|
||||
b /= a + 1e-6
|
||||
|
||||
# Convert back to PIL images
|
||||
r = Image.fromarray(numpy.uint8(numpy.clip(r, 0, 255)))
|
||||
g = Image.fromarray(numpy.uint8(numpy.clip(g, 0, 255)))
|
||||
b = Image.fromarray(numpy.uint8(numpy.clip(b, 0, 255)))
|
||||
|
||||
# Merge back into a single image
|
||||
result_image = Image.merge("RGBA", (r, g, b, a_orig))
|
||||
|
||||
image_dto = context.images.save(image=result_image)
|
||||
image_dto = context.images.save(image=blur_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -355,6 +319,7 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tags=["image", "unsharp_mask"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Applies an unsharp mask filter to an image"""
|
||||
@@ -842,7 +807,7 @@ CHANNEL_FORMATS = {
|
||||
"value",
|
||||
],
|
||||
category="image",
|
||||
version="1.2.3",
|
||||
version="1.2.2",
|
||||
)
|
||||
class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Add or subtract a value from a specific color channel of an image."""
|
||||
@@ -852,22 +817,18 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGBA")
|
||||
pil_image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# extract the channel and mode from the input and reference tuple
|
||||
mode = CHANNEL_FORMATS[self.channel][0]
|
||||
channel_number = CHANNEL_FORMATS[self.channel][1]
|
||||
|
||||
# Convert PIL image to new format
|
||||
converted_image = numpy.array(image.convert(mode)).astype(int)
|
||||
converted_image = numpy.array(pil_image.convert(mode)).astype(int)
|
||||
image_channel = converted_image[:, :, channel_number]
|
||||
|
||||
if self.channel == "Hue (HSV)":
|
||||
# loop around the values because hue is special
|
||||
image_channel = (image_channel + self.offset) % 256
|
||||
else:
|
||||
# Adjust the value, clipping to 0..255
|
||||
image_channel = numpy.clip(image_channel + self.offset, 0, 255)
|
||||
# Adjust the value, clipping to 0..255
|
||||
image_channel = numpy.clip(image_channel + self.offset, 0, 255)
|
||||
|
||||
# Put the channel back into the image
|
||||
converted_image[:, :, channel_number] = image_channel
|
||||
@@ -875,10 +836,6 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Convert back to RGBA format and output
|
||||
pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA")
|
||||
|
||||
# restore the alpha channel
|
||||
if self.channel != "Alpha (RGBA)":
|
||||
pil_image.putalpha(image.getchannel("A"))
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -906,7 +863,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"value",
|
||||
],
|
||||
category="image",
|
||||
version="1.2.3",
|
||||
version="1.2.2",
|
||||
)
|
||||
class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Scale a specific color channel of an image."""
|
||||
@@ -917,14 +874,14 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
invert_channel: bool = InputField(default=False, description="Invert the channel after scaling")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGBA")
|
||||
pil_image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# extract the channel and mode from the input and reference tuple
|
||||
mode = CHANNEL_FORMATS[self.channel][0]
|
||||
channel_number = CHANNEL_FORMATS[self.channel][1]
|
||||
|
||||
# Convert PIL image to new format
|
||||
converted_image = numpy.array(image.convert(mode)).astype(float)
|
||||
converted_image = numpy.array(pil_image.convert(mode)).astype(float)
|
||||
image_channel = converted_image[:, :, channel_number]
|
||||
|
||||
# Adjust the value, clipping to 0..255
|
||||
@@ -940,10 +897,6 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Convert back to RGBA format and output
|
||||
pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA")
|
||||
|
||||
# restore the alpha channel
|
||||
if self.channel != "Alpha (RGBA)":
|
||||
pil_image.putalpha(image.getchannel("A"))
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -975,13 +928,13 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Canvas Paste Back",
|
||||
tags=["image", "combine"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Combines two images by using the mask provided. Intended for use on the Unified Canvas."""
|
||||
|
||||
source_image: ImageField = InputField(description="The source image")
|
||||
target_image: ImageField = InputField(description="The target image")
|
||||
target_image: ImageField = InputField(default=None, description="The target image")
|
||||
mask: ImageField = InputField(
|
||||
description="The mask to use when pasting",
|
||||
)
|
||||
@@ -1009,10 +962,10 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@invocation(
|
||||
"mask_from_id",
|
||||
title="Mask from Segmented Image",
|
||||
title="Mask from ID",
|
||||
tags=["image", "mask", "id"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generate a mask for a particular color in an ID Map"""
|
||||
@@ -1022,24 +975,40 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
threshold: int = InputField(default=100, description="Threshold for color detection")
|
||||
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
def rgba_to_hex(self, rgba_color: tuple[int, int, int, int]):
|
||||
r, g, b, a = rgba_color
|
||||
hex_code = "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, int(a * 255))
|
||||
return hex_code
|
||||
|
||||
np_color = numpy.array(self.color.tuple())
|
||||
def id_to_mask(self, id_mask: Image.Image, color: tuple[int, int, int, int], threshold: int = 100):
|
||||
if id_mask.mode != "RGB":
|
||||
id_mask = id_mask.convert("RGB")
|
||||
|
||||
# Can directly just use the tuple but I'll leave this rgba_to_hex here
|
||||
# incase anyone prefers using hex codes directly instead of the color picker
|
||||
hex_color_str = self.rgba_to_hex(color)
|
||||
rgb_color = numpy.array([int(hex_color_str[i : i + 2], 16) for i in (1, 3, 5)])
|
||||
|
||||
# Maybe there's a faster way to calculate this distance but I can't think of any right now.
|
||||
color_distance = numpy.linalg.norm(image - np_color, axis=-1)
|
||||
color_distance = numpy.linalg.norm(id_mask - rgb_color, axis=-1)
|
||||
|
||||
# Create a mask based on the threshold and the distance calculated above
|
||||
binary_mask = (color_distance < self.threshold).astype(numpy.uint8) * 255
|
||||
binary_mask = (color_distance < threshold).astype(numpy.uint8) * 255
|
||||
|
||||
# Convert the mask back to PIL
|
||||
binary_mask_pil = Image.fromarray(binary_mask)
|
||||
|
||||
if self.invert:
|
||||
binary_mask_pil = ImageOps.invert(binary_mask_pil)
|
||||
return binary_mask_pil
|
||||
|
||||
image_dto = context.images.save(image=binary_mask_pil, image_category=ImageCategory.MASK)
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
mask = self.id_to_mask(image, self.color.tuple(), self.threshold)
|
||||
|
||||
if self.invert:
|
||||
mask = ImageOps.invert(mask)
|
||||
|
||||
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -1050,7 +1019,7 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tags=["image", "mask", "id"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Deprecated,
|
||||
classification=Classification.Internal,
|
||||
)
|
||||
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Handles Canvas V2 image output masking and cropping"""
|
||||
@@ -1086,357 +1055,3 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_dto = context.images.save(image=generated_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1"
|
||||
)
|
||||
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
|
||||
The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
|
||||
The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
|
||||
If the fade size is 0, the mask is returned as-is.
|
||||
"""
|
||||
|
||||
mask: ImageField = InputField(description="The mask to expand")
|
||||
threshold: int = InputField(default=0, ge=0, le=255, description="The threshold for the binary mask (0-255)")
|
||||
fade_size_px: int = InputField(default=32, ge=0, description="The size of the fade in pixels")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_mask = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
|
||||
if self.fade_size_px == 0:
|
||||
# If the fade size is 0, just return the mask as-is.
|
||||
image_dto = context.images.save(image=pil_mask, image_category=ImageCategory.MASK)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
np_mask = numpy.array(pil_mask)
|
||||
|
||||
# Threshold the mask to create a binary mask - 0 for black, 255 for white
|
||||
# If we don't threshold we can get some weird artifacts
|
||||
np_mask = numpy.where(np_mask > self.threshold, 255, 0).astype(numpy.uint8)
|
||||
|
||||
# Create a mask for the black region (1 where black, 0 otherwise)
|
||||
black_mask = (np_mask == 0).astype(numpy.uint8)
|
||||
|
||||
# Invert the black region
|
||||
bg_mask = 1 - black_mask
|
||||
|
||||
# Create a distance transform of the inverted mask
|
||||
dist = cv2.distanceTransform(bg_mask, cv2.DIST_L2, 5)
|
||||
|
||||
# Normalize distances so that pixels <fade_size_px become a linear gradient (0 to 1)
|
||||
d_norm = numpy.clip(dist / self.fade_size_px, 0, 1)
|
||||
|
||||
# Control points: x values (normalized distance) and corresponding fade pct y values.
|
||||
|
||||
# There are some magic numbers here that are used to create a smooth transition:
|
||||
# - The first point is at 0% of fade size from edge of mask (meaning the edge of the mask), and is 0% fade (black)
|
||||
# - The second point is 1px from the edge of the mask and also has 0% fade, effectively expanding the mask
|
||||
# by 1px. This fixes an issue where artifacts can occur at the edge of the mask
|
||||
# - The third point is at 20% of the fade size from the edge of the mask and has 20% fade
|
||||
# - The fourth point is at 80% of the fade size from the edge of the mask and has 90% fade
|
||||
# - The last point is at 100% of the fade size from the edge of the mask and has 100% fade (white)
|
||||
|
||||
# x values: 0 = mask edge, 1 = fade_size_px from edge
|
||||
x_control = numpy.array([0.0, 1.0 / self.fade_size_px, 0.2, 0.8, 1.0])
|
||||
# y values: 0 = black, 1 = white
|
||||
y_control = numpy.array([0.0, 0.0, 0.2, 0.9, 1.0])
|
||||
|
||||
# Fit a cubic polynomial that smoothly passes through the control points
|
||||
coeffs = numpy.polyfit(x_control, y_control, 3)
|
||||
poly = numpy.poly1d(coeffs)
|
||||
|
||||
# Evaluate the polynomial
|
||||
feather = poly(d_norm)
|
||||
|
||||
# The polynomial fit isn't perfect. Points beyond the fade distance are likely to be slightly less than 1.0,
|
||||
# even though the control points indicate that they should be exactly 1.0. This is due to the nature of the
|
||||
# polynomial fit, which is a best approximation of the control points but not an exact match.
|
||||
|
||||
# When this occurs, the area outside the mask and fade-out will not be 100% transparent. For example, it may
|
||||
# have an alpha value of 1 instead of 0. So we must force pixels at or beyond the fade distance to exactly 1.0.
|
||||
|
||||
# Force pixels at or beyond the fade distance to exactly 1.0
|
||||
feather = numpy.where(d_norm >= 1.0, 1.0, feather)
|
||||
|
||||
# Clip any other values to ensure they're in the valid range [0,1]
|
||||
feather = numpy.clip(feather, 0, 1)
|
||||
|
||||
# Build final image.
|
||||
np_result = numpy.where(black_mask == 1, 0, (feather * 255).astype(numpy.uint8))
|
||||
|
||||
# Convert back to PIL, grayscale
|
||||
pil_result = Image.fromarray(np_result.astype(numpy.uint8), mode="L")
|
||||
|
||||
image_dto = context.images.save(image=pil_result, image_category=ImageCategory.MASK)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"apply_mask_to_image",
|
||||
title="Apply Mask to Image",
|
||||
tags=["image", "mask", "blend"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""
|
||||
Extracts a region from a generated image using a mask and blends it seamlessly onto a source image.
|
||||
The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
|
||||
"""
|
||||
|
||||
image: ImageField = InputField(description="The image from which to extract the masked region")
|
||||
mask: ImageField = InputField(description="The mask defining the region (black=keep, white=discard)")
|
||||
invert_mask: bool = InputField(
|
||||
default=False,
|
||||
description="Whether to invert the mask before applying it",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Load images
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
mask = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
|
||||
if self.invert_mask:
|
||||
# Invert the mask if requested
|
||||
mask = ImageOps.invert(mask.copy())
|
||||
|
||||
# Combine the mask as the alpha channel of the image
|
||||
r, g, b, _ = image.split() # Split the image into RGB and alpha channels
|
||||
result_image = Image.merge("RGBA", (r, g, b, mask)) # Use the mask as the new alpha channel
|
||||
|
||||
# Save the resulting image
|
||||
image_dto = context.images.save(image=result_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_noise",
|
||||
title="Add Image Noise",
|
||||
tags=["image", "noise"],
|
||||
category="image",
|
||||
version="1.1.0",
|
||||
)
|
||||
class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Add noise to an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to add noise to")
|
||||
mask: Optional[ImageField] = InputField(
|
||||
default=None, description="Optional mask determining where to apply noise (black=noise, white=no noise)"
|
||||
)
|
||||
seed: int = InputField(
|
||||
default=0,
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description=FieldDescriptions.seed,
|
||||
)
|
||||
noise_type: Literal["gaussian", "salt_and_pepper"] = InputField(
|
||||
default="gaussian",
|
||||
description="The type of noise to add",
|
||||
)
|
||||
amount: float = InputField(default=0.1, ge=0, le=1, description="The amount of noise to add")
|
||||
noise_color: bool = InputField(default=True, description="Whether to add colored noise")
|
||||
size: int = InputField(default=1, ge=1, description="The size of the noise points")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
|
||||
# Save out the alpha channel
|
||||
alpha = image.getchannel("A")
|
||||
|
||||
# Set the seed for numpy random
|
||||
rs = numpy.random.RandomState(numpy.random.MT19937(numpy.random.SeedSequence(self.seed)))
|
||||
|
||||
if self.noise_type == "gaussian":
|
||||
if self.noise_color:
|
||||
noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size, 3)) * 255
|
||||
else:
|
||||
noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size)) * 255
|
||||
noise = numpy.stack([noise] * 3, axis=-1)
|
||||
elif self.noise_type == "salt_and_pepper":
|
||||
if self.noise_color:
|
||||
noise = rs.choice(
|
||||
[0, 255], (image.height // self.size, image.width // self.size, 3), p=[1 - self.amount, self.amount]
|
||||
)
|
||||
else:
|
||||
noise = rs.choice(
|
||||
[0, 255], (image.height // self.size, image.width // self.size), p=[1 - self.amount, self.amount]
|
||||
)
|
||||
noise = numpy.stack([noise] * 3, axis=-1)
|
||||
|
||||
noise = Image.fromarray(noise.astype(numpy.uint8), mode="RGB").resize(
|
||||
(image.width, image.height), Image.Resampling.NEAREST
|
||||
)
|
||||
|
||||
# Create a noisy version of the input image
|
||||
noisy_image = Image.blend(image.convert("RGB"), noise, self.amount).convert("RGBA")
|
||||
|
||||
# Apply mask if provided
|
||||
if self.mask is not None:
|
||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
|
||||
if mask_image.size != image.size:
|
||||
mask_image = mask_image.resize(image.size, Image.Resampling.LANCZOS)
|
||||
|
||||
result_image = image.copy()
|
||||
mask_image = ImageOps.invert(mask_image)
|
||||
result_image.paste(noisy_image, (0, 0), mask=mask_image)
|
||||
else:
|
||||
result_image = noisy_image
|
||||
|
||||
# Paste back the alpha channel from the original image
|
||||
result_image.putalpha(alpha)
|
||||
|
||||
image_dto = context.images.save(image=result_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"crop_image_to_bounding_box",
|
||||
title="Crop Image to Bounding Box",
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
tags=["image", "crop"],
|
||||
)
|
||||
class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Crop an image to the given bounding box. If the bounding box is omitted, the image is cropped to the non-transparent pixels."""
|
||||
|
||||
image: ImageField = InputField(description="The image to crop")
|
||||
bounding_box: BoundingBoxField | None = InputField(
|
||||
default=None, description="The bounding box to crop the image to"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
bounding_box = self.bounding_box.tuple() if self.bounding_box is not None else image.getbbox()
|
||||
|
||||
cropped_image = image.crop(bounding_box)
|
||||
|
||||
image_dto = context.images.save(image=cropped_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"paste_image_into_bounding_box",
|
||||
title="Paste Image into Bounding Box",
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
tags=["image", "crop"],
|
||||
)
|
||||
class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Paste the source image into the target image at the given bounding box.
|
||||
|
||||
The source image must be the same size as the bounding box, and the bounding box must fit within the target image."""
|
||||
|
||||
source_image: ImageField = InputField(description="The image to paste")
|
||||
target_image: ImageField = InputField(description="The image to paste into")
|
||||
bounding_box: BoundingBoxField = InputField(description="The bounding box to paste the image into")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
source_image = context.images.get_pil(self.source_image.image_name, mode="RGBA")
|
||||
target_image = context.images.get_pil(self.target_image.image_name, mode="RGBA")
|
||||
|
||||
bounding_box = self.bounding_box.tuple()
|
||||
|
||||
target_image.paste(source_image, bounding_box, source_image)
|
||||
|
||||
image_dto = context.images.save(image=target_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_kontext_image_prep",
|
||||
title="FLUX Kontext Image Prep",
|
||||
tags=["image", "concatenate", "flux", "kontext"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Prepares an image or images for use with FLUX Kontext. The first/single image is resized to the nearest
|
||||
preferred Kontext resolution. All other images are concatenated horizontally, maintaining their aspect ratio."""
|
||||
|
||||
images: list[ImageField] = InputField(
|
||||
description="The images to concatenate",
|
||||
min_length=1,
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
use_preferred_resolution: bool = InputField(
|
||||
default=True, description="Use FLUX preferred resolutions for the first image"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
|
||||
|
||||
# Step 1: Load all images
|
||||
pil_images = []
|
||||
for image_field in self.images:
|
||||
image = context.images.get_pil(image_field.image_name, mode="RGBA")
|
||||
pil_images.append(image)
|
||||
|
||||
# Step 2: Determine target resolution for the first image
|
||||
first_image = pil_images[0]
|
||||
width, height = first_image.size
|
||||
|
||||
if self.use_preferred_resolution:
|
||||
aspect_ratio = width / height
|
||||
|
||||
# Find the closest preferred resolution for the first image
|
||||
_, target_width, target_height = min(
|
||||
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
|
||||
)
|
||||
|
||||
# Apply BFL's scaling formula
|
||||
scaled_height = 2 * int(target_height / 16)
|
||||
final_height = 8 * scaled_height # This will be consistent for all images
|
||||
scaled_width = 2 * int(target_width / 16)
|
||||
first_width = 8 * scaled_width
|
||||
else:
|
||||
# Use original dimensions of first image, ensuring divisibility by 16
|
||||
final_height = 16 * (height // 16)
|
||||
first_width = 16 * (width // 16)
|
||||
# Ensure minimum dimensions
|
||||
if final_height < 16:
|
||||
final_height = 16
|
||||
if first_width < 16:
|
||||
first_width = 16
|
||||
|
||||
# Step 3: Process and resize all images with consistent height
|
||||
processed_images = []
|
||||
total_width = 0
|
||||
|
||||
for i, image in enumerate(pil_images):
|
||||
if i == 0:
|
||||
# First image uses the calculated dimensions
|
||||
final_width = first_width
|
||||
else:
|
||||
# Subsequent images maintain aspect ratio with the same height
|
||||
img_aspect_ratio = image.width / image.height
|
||||
# Calculate width that maintains aspect ratio at the target height
|
||||
calculated_width = int(final_height * img_aspect_ratio)
|
||||
# Ensure width is divisible by 16 for proper VAE encoding
|
||||
final_width = 16 * (calculated_width // 16)
|
||||
# Ensure minimum width
|
||||
if final_width < 16:
|
||||
final_width = 16
|
||||
|
||||
# Resize image to calculated dimensions
|
||||
resized_image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
|
||||
processed_images.append(resized_image)
|
||||
total_width += final_width
|
||||
|
||||
# Step 4: Concatenate images horizontally
|
||||
concatenated_image = Image.new("RGB", (total_width, final_height))
|
||||
x_offset = 0
|
||||
for img in processed_images:
|
||||
concatenated_image.paste(img, (x_offset, 0))
|
||||
x_offset += img.width
|
||||
|
||||
# Save the concatenated image
|
||||
image_dto = context.images.save(image=concatenated_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -27,15 +27,14 @@ from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
|
||||
|
||||
|
||||
@invocation(
|
||||
"i2l",
|
||||
title="Image to Latents - SD1.5, SDXL",
|
||||
title="Image to Latents",
|
||||
tags=["latents", "image", "vae", "i2l"],
|
||||
category="latents",
|
||||
version="1.1.1",
|
||||
version="1.1.0",
|
||||
)
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
@@ -53,24 +52,11 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
|
||||
|
||||
@classmethod
|
||||
@staticmethod
|
||||
def vae_encode(
|
||||
cls,
|
||||
vae_info: LoadedModel,
|
||||
upcast: bool,
|
||||
tiled: bool,
|
||||
image_tensor: torch.Tensor,
|
||||
tile_size: int = 0,
|
||||
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
|
||||
operation="encode",
|
||||
image_tensor=image_tensor,
|
||||
vae=vae_info.model,
|
||||
tile_size=tile_size if tiled else None,
|
||||
fp32=upcast,
|
||||
)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
orig_dtype = vae.dtype
|
||||
if upcast:
|
||||
@@ -127,7 +113,6 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
@@ -135,11 +120,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
|
||||
context.util.signal_progress("Running VAE encoder")
|
||||
latents = self.vae_encode(
|
||||
vae_info=vae_info,
|
||||
upcast=self.fp32,
|
||||
tiled=self.tiled or context.config.get().force_tiled_decode,
|
||||
image_tensor=image_tensor,
|
||||
tile_size=self.tile_size,
|
||||
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
|
||||
)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
|
||||
@@ -127,16 +127,13 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
|
||||
return infilled
|
||||
|
||||
|
||||
LAMA_MODEL_URL = "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt"
|
||||
|
||||
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
with self._context.models.load_remote_model(
|
||||
source=LAMA_MODEL_URL,
|
||||
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
loader=LaMA.load_jit_model,
|
||||
) as model:
|
||||
lama = LaMA(model)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user