mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-17 14:28:03 -05:00
Compare commits
120 Commits
speed-up-h
...
feature/mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b85f2bc87d | ||
|
|
b06d63fb34 | ||
|
|
5278a64301 | ||
|
|
4de4473c0f | ||
|
|
2c28a850ca | ||
|
|
6dada3326d | ||
|
|
2dfdc02ec8 | ||
|
|
1f19db4c6a | ||
|
|
7c150c27f2 | ||
|
|
248916c190 | ||
|
|
be8b99eed5 | ||
|
|
2ad0752582 | ||
|
|
ba1f8878dd | ||
|
|
bc524026f9 | ||
|
|
ad7c571983 | ||
|
|
8559c6a392 | ||
|
|
c7904a32f4 | ||
|
|
17f5484f5b | ||
|
|
86a372b02f | ||
|
|
2e9aa9391d | ||
|
|
0c8112cf28 | ||
|
|
019898c7be | ||
|
|
2b1ff8d196 | ||
|
|
79fb691b4d | ||
|
|
560ae17e21 | ||
|
|
2bd1ab2f1c | ||
|
|
ed43472582 | ||
|
|
6e5e9176c0 | ||
|
|
4c6bcdbc18 | ||
|
|
20e6d4fa3c | ||
|
|
8e51392910 | ||
|
|
0b1c2acd61 | ||
|
|
86ac55ab5f | ||
|
|
3e82f63c7e | ||
|
|
631f6cae19 | ||
|
|
0845a0ed84 | ||
|
|
46c8ce9fed | ||
|
|
13a9ea35b5 | ||
|
|
94e8d1b6d5 | ||
|
|
2b1dc74080 | ||
|
|
f7e558d165 | ||
|
|
d959276217 | ||
|
|
dfcf38be91 | ||
|
|
fbded1c0f2 | ||
|
|
ad2926a24c | ||
|
|
34d5cad4c9 | ||
|
|
60aa3d4893 | ||
|
|
5c2884569e | ||
|
|
a1307b9f2e | ||
|
|
f505ec64ba | ||
|
|
f22eb368a3 | ||
|
|
96ae22c7e0 | ||
|
|
f5447cdc23 | ||
|
|
c76a6bd65f | ||
|
|
6c4eeaa569 | ||
|
|
1bbd13ead7 | ||
|
|
321b939d0e | ||
|
|
8fb77e431e | ||
|
|
083a4f3faa | ||
|
|
2005411f7e | ||
|
|
ba7b1b2665 | ||
|
|
b7ffd36cc6 | ||
|
|
199ddd6623 | ||
|
|
a7207ed8cf | ||
|
|
6bb2dda3f1 | ||
|
|
c1e5cd5893 | ||
|
|
ff249a2315 | ||
|
|
c58f8c3269 | ||
|
|
ed772a7107 | ||
|
|
cb0b389b4b | ||
|
|
8892df1d97 | ||
|
|
bc5f356390 | ||
|
|
bcb85e100d | ||
|
|
1f27ddc07d | ||
|
|
7a2b606001 | ||
|
|
83ddcc5f3a | ||
|
|
55fa785561 | ||
|
|
06429028c8 | ||
|
|
8b6e322697 | ||
|
|
54a67459bf | ||
|
|
7fe5283e74 | ||
|
|
fe0391c86b | ||
|
|
25386a76ef | ||
|
|
fd30cb4d90 | ||
|
|
0266946d3d | ||
|
|
a7f91b3e01 | ||
|
|
de0b72528c | ||
|
|
2932652787 | ||
|
|
db6bc7305a | ||
|
|
a5db204629 | ||
|
|
8e2b61e19f | ||
|
|
a3faa3792a | ||
|
|
c16eba78ab | ||
|
|
1a191c4655 | ||
|
|
e36d925bce | ||
|
|
b1ba18b3d1 | ||
|
|
aff46759f9 | ||
|
|
d7b7dcc7fe | ||
|
|
889a26c5b6 | ||
|
|
b4c774896a | ||
|
|
afbe889d35 | ||
|
|
9c1e52b1ef | ||
|
|
3f5ab02da9 | ||
|
|
bf48e8a03a | ||
|
|
e52434cb99 | ||
|
|
483bdbcb9f | ||
|
|
ae421fb4ab | ||
|
|
cc295a9f0a | ||
|
|
a7e23af9c6 | ||
|
|
3de4390711 | ||
|
|
3ceee2b2b2 | ||
|
|
5c7ed24aab | ||
|
|
183c9c4799 | ||
|
|
8baf3f78a2 | ||
|
|
ac2eb16a65 | ||
|
|
4aa7bee4b9 | ||
|
|
7e5ba2795e | ||
|
|
97a6c6eea7 | ||
|
|
f0e60a4ba2 | ||
|
|
aa089e8108 |
33
.github/actions/install-frontend-deps/action.yml
vendored
33
.github/actions/install-frontend-deps/action.yml
vendored
@@ -1,33 +0,0 @@
|
||||
name: install frontend dependencies
|
||||
description: Installs frontend dependencies with pnpm, with caching
|
||||
runs:
|
||||
using: 'composite'
|
||||
steps:
|
||||
- name: setup node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
|
||||
- name: setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: 8
|
||||
run_install: false
|
||||
|
||||
- name: get pnpm store directory
|
||||
shell: bash
|
||||
run: |
|
||||
echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
|
||||
|
||||
- name: setup cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ env.STORE_PATH }}
|
||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-store-
|
||||
|
||||
- name: install frontend dependencies
|
||||
run: pnpm install --prefer-frozen-lockfile
|
||||
shell: bash
|
||||
working-directory: invokeai/frontend/web
|
||||
28
.github/pr_labels.yml
vendored
28
.github/pr_labels.yml
vendored
@@ -1,59 +1,59 @@
|
||||
root:
|
||||
Root:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: '*'
|
||||
|
||||
python-deps:
|
||||
PythonDeps:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'pyproject.toml'
|
||||
|
||||
python:
|
||||
Python:
|
||||
- changed-files:
|
||||
- all-globs-to-any-file:
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
|
||||
python-tests:
|
||||
PythonTests:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'tests/**'
|
||||
|
||||
ci-cd:
|
||||
CICD:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: .github/**
|
||||
|
||||
docker:
|
||||
Docker:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: docker/**
|
||||
|
||||
installer:
|
||||
Installer:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: installer/**
|
||||
|
||||
docs:
|
||||
Documentation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: docs/**
|
||||
|
||||
invocations:
|
||||
Invocations:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/app/invocations/**'
|
||||
|
||||
backend:
|
||||
Backend:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/backend/**'
|
||||
|
||||
api:
|
||||
Api:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/app/api/**'
|
||||
|
||||
services:
|
||||
Services:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/app/services/**'
|
||||
|
||||
frontend-deps:
|
||||
FrontendDeps:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- '**/*/package.json'
|
||||
- '**/*/pnpm-lock.yaml'
|
||||
|
||||
frontend:
|
||||
Frontend:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/frontend/web/**'
|
||||
|
||||
2
.github/workflows/build-container.yml
vendored
2
.github/workflows/build-container.yml
vendored
@@ -11,7 +11,7 @@ on:
|
||||
- 'docker/docker-entrypoint.sh'
|
||||
- 'workflows/build-container.yml'
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
|
||||
45
.github/workflows/build-installer.yml
vendored
45
.github/workflows/build-installer.yml
vendored
@@ -1,45 +0,0 @@
|
||||
# Builds and uploads the installer and python build artifacts.
|
||||
|
||||
name: build installer
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
build-installer:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <2 min
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: setup python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install pypa/build
|
||||
run: pip install --upgrade build
|
||||
|
||||
- name: setup frontend
|
||||
uses: ./.github/actions/install-frontend-deps
|
||||
|
||||
- name: create installer
|
||||
id: create_installer
|
||||
run: ./create_installer.sh
|
||||
working-directory: installer
|
||||
|
||||
- name: upload python distribution artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: ${{ steps.create_installer.outputs.DIST_PATH }}
|
||||
|
||||
- name: upload installer artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }}
|
||||
path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}
|
||||
68
.github/workflows/frontend-checks.yml
vendored
68
.github/workflows/frontend-checks.yml
vendored
@@ -1,68 +0,0 @@
|
||||
# Runs frontend code quality checks.
|
||||
#
|
||||
# Checks for changes to frontend files before running the checks.
|
||||
# When manually triggered or when called from another workflow, always runs the checks.
|
||||
|
||||
name: 'frontend checks'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
jobs:
|
||||
frontend-checks:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10 # expected run time: <2 min
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed frontend files
|
||||
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
frontend:
|
||||
- 'invokeai/frontend/web/**'
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
uses: ./.github/actions/install-frontend-deps
|
||||
|
||||
- name: tsc
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: 'pnpm lint:tsc'
|
||||
shell: bash
|
||||
|
||||
- name: dpdm
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: 'pnpm lint:dpdm'
|
||||
shell: bash
|
||||
|
||||
- name: eslint
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: 'pnpm lint:eslint'
|
||||
shell: bash
|
||||
|
||||
- name: prettier
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: 'pnpm lint:prettier'
|
||||
shell: bash
|
||||
|
||||
- name: knip
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: 'pnpm lint:knip'
|
||||
shell: bash
|
||||
48
.github/workflows/frontend-tests.yml
vendored
48
.github/workflows/frontend-tests.yml
vendored
@@ -1,48 +0,0 @@
|
||||
# Runs frontend tests.
|
||||
#
|
||||
# Checks for changes to frontend files before running the tests.
|
||||
# When manually triggered or called from another workflow, always runs the tests.
|
||||
|
||||
name: 'frontend tests'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
jobs:
|
||||
frontend-tests:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10 # expected run time: <2 min
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed frontend files
|
||||
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
frontend:
|
||||
- 'invokeai/frontend/web/**'
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
uses: ./.github/actions/install-frontend-deps
|
||||
|
||||
- name: vitest
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: 'pnpm test:no-watch'
|
||||
shell: bash
|
||||
12
.github/workflows/label-pr.yml
vendored
12
.github/workflows/label-pr.yml
vendored
@@ -1,6 +1,6 @@
|
||||
name: 'label PRs'
|
||||
name: "Pull Request Labeler"
|
||||
on:
|
||||
- pull_request_target
|
||||
- pull_request_target
|
||||
|
||||
jobs:
|
||||
labeler:
|
||||
@@ -9,10 +9,8 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: checkout
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: label PRs
|
||||
uses: actions/labeler@v5
|
||||
- uses: actions/labeler@v5
|
||||
with:
|
||||
configuration-path: .github/pr_labels.yml
|
||||
configuration-path: .github/pr_labels.yml
|
||||
43
.github/workflows/lint-frontend.yml
vendored
Normal file
43
.github/workflows/lint-frontend.yml
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
name: Lint frontend
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
jobs:
|
||||
lint-frontend:
|
||||
if: github.event.pull_request.draft == false
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Setup Node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: '8.12.1'
|
||||
- name: Install dependencies
|
||||
run: 'pnpm install --prefer-frozen-lockfile'
|
||||
- name: Typescript
|
||||
run: 'pnpm run lint:tsc'
|
||||
- name: Madge
|
||||
run: 'pnpm run lint:madge'
|
||||
- name: ESLint
|
||||
run: 'pnpm run lint:eslint'
|
||||
- name: Prettier
|
||||
run: 'pnpm run lint:prettier'
|
||||
54
.github/workflows/mkdocs-material.yml
vendored
54
.github/workflows/mkdocs-material.yml
vendored
@@ -1,49 +1,51 @@
|
||||
# This is a mostly a copy-paste from https://github.com/squidfunk/mkdocs-material/blob/master/docs/publishing-your-site.md
|
||||
|
||||
name: mkdocs
|
||||
|
||||
name: mkdocs-material
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch:
|
||||
- 'refs/heads/main'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
mkdocs-material:
|
||||
if: github.event.pull_request.draft == false
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
|
||||
REPO_NAME: '${{ github.repository }}'
|
||||
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
|
||||
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: checkout sources
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: setup python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: set cache id
|
||||
run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
|
||||
- name: install requirements
|
||||
env:
|
||||
PIP_USE_PEP517: 1
|
||||
run: |
|
||||
python -m \
|
||||
pip install ".[docs]"
|
||||
|
||||
- name: use cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
key: mkdocs-material-${{ env.cache_id }}
|
||||
path: .cache
|
||||
restore-keys: |
|
||||
mkdocs-material-
|
||||
- name: confirm buildability
|
||||
run: |
|
||||
python -m \
|
||||
mkdocs build \
|
||||
--clean \
|
||||
--verbose
|
||||
|
||||
- name: install dependencies
|
||||
run: python -m pip install ".[docs]"
|
||||
|
||||
- name: build & deploy
|
||||
run: mkdocs gh-deploy --force
|
||||
- name: deploy to gh-pages
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
run: |
|
||||
python -m \
|
||||
mkdocs gh-deploy \
|
||||
--clean \
|
||||
--force
|
||||
|
||||
67
.github/workflows/pypi-release.yml
vendored
Normal file
67
.github/workflows/pypi-release.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
name: PyPI Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
publish_package:
|
||||
description: 'Publish build on PyPi? [true/false]'
|
||||
required: true
|
||||
default: 'false'
|
||||
|
||||
jobs:
|
||||
build-and-release:
|
||||
if: github.repository == 'invoke-ai/InvokeAI'
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
TWINE_NON_INTERACTIVE: 1
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: '8.12.1'
|
||||
|
||||
- name: Install frontend dependencies
|
||||
run: pnpm install --prefer-frozen-lockfile
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
- name: Build frontend
|
||||
run: pnpm run build
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
- name: Install python dependencies
|
||||
run: pip install --upgrade build twine
|
||||
|
||||
- name: Build python package
|
||||
run: python3 -m build
|
||||
|
||||
- name: Upload build as workflow artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist
|
||||
|
||||
- name: Check distribution
|
||||
run: twine check dist/*
|
||||
|
||||
- name: Check PyPI versions
|
||||
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
|
||||
run: |
|
||||
pip install --upgrade requests
|
||||
python -c "\
|
||||
import scripts.pypi_helper; \
|
||||
EXISTS=scripts.pypi_helper.local_on_pypi(); \
|
||||
print(f'PACKAGE_EXISTS={EXISTS}')" >> $GITHUB_ENV
|
||||
|
||||
- name: Publish build on PyPi
|
||||
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != '' && github.event.inputs.publish_package == 'true'
|
||||
run: twine upload dist/*
|
||||
64
.github/workflows/python-checks.yml
vendored
64
.github/workflows/python-checks.yml
vendored
@@ -1,64 +0,0 @@
|
||||
# Runs python code quality checks.
|
||||
#
|
||||
# Checks for changes to python files before running the checks.
|
||||
# When manually triggered or called from another workflow, always runs the tests.
|
||||
#
|
||||
# TODO: Add mypy or pyright to the checks.
|
||||
|
||||
name: 'python checks'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
python-checks:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed python files
|
||||
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install ruff
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: pip install ruff
|
||||
shell: bash
|
||||
|
||||
- name: ruff check
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: ruff check --output-format=github .
|
||||
shell: bash
|
||||
|
||||
- name: ruff format
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: ruff format --check .
|
||||
shell: bash
|
||||
94
.github/workflows/python-tests.yml
vendored
94
.github/workflows/python-tests.yml
vendored
@@ -1,94 +0,0 @@
|
||||
# Runs python tests on a matrix of python versions and platforms.
|
||||
#
|
||||
# Checks for changes to python files before running the tests.
|
||||
# When manually triggered or called from another workflow, always runs the tests.
|
||||
|
||||
name: 'python tests'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
matrix:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- '3.10'
|
||||
- '3.11'
|
||||
platform:
|
||||
- linux-cuda-11_7
|
||||
- linux-rocm-5_2
|
||||
- linux-cpu
|
||||
- macos-default
|
||||
- windows-cpu
|
||||
include:
|
||||
- platform: linux-cuda-11_7
|
||||
os: ubuntu-22.04
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-rocm-5_2
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-cpu
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: macos-default
|
||||
os: macOS-12
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: windows-cpu
|
||||
os: windows-2022
|
||||
github-env: $env:GITHUB_ENV
|
||||
name: 'py${{ matrix.python-version }}: ${{ matrix.platform }}'
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
|
||||
env:
|
||||
PIP_USE_PEP517: '1'
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed python files
|
||||
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
env:
|
||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||
run: >
|
||||
pip3 install --editable=".[test]"
|
||||
|
||||
- name: run pytest
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: pytest
|
||||
96
.github/workflows/release.yml
vendored
96
.github/workflows/release.yml
vendored
@@ -1,96 +0,0 @@
|
||||
# Main release workflow. Triggered on tag push or manual trigger.
|
||||
#
|
||||
# - Runs all code checks and tests
|
||||
# - Verifies the app version matches the tag version.
|
||||
# - Builds the installer and build, uploading them as artifacts.
|
||||
# - Publishes to TestPyPI and PyPI. Both are conditional on the previous steps passing and require a manual approval.
|
||||
#
|
||||
# See docs/RELEASE.md for more information on the release process.
|
||||
|
||||
name: release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
check-version:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check python version
|
||||
uses: samuelcolvin/check-python-version@v4
|
||||
id: check-python-version
|
||||
with:
|
||||
version_file_path: invokeai/version/invokeai_version.py
|
||||
|
||||
frontend-checks:
|
||||
uses: ./.github/workflows/frontend-checks.yml
|
||||
|
||||
frontend-tests:
|
||||
uses: ./.github/workflows/frontend-tests.yml
|
||||
|
||||
python-checks:
|
||||
uses: ./.github/workflows/python-checks.yml
|
||||
|
||||
python-tests:
|
||||
uses: ./.github/workflows/python-tests.yml
|
||||
|
||||
build:
|
||||
uses: ./.github/workflows/build-installer.yml
|
||||
|
||||
publish-testpypi:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
needs:
|
||||
[
|
||||
check-version,
|
||||
frontend-checks,
|
||||
frontend-tests,
|
||||
python-checks,
|
||||
python-tests,
|
||||
build,
|
||||
]
|
||||
environment:
|
||||
name: testpypi
|
||||
url: https://test.pypi.org/p/invokeai
|
||||
steps:
|
||||
- name: download distribution from build job
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist/
|
||||
|
||||
- name: publish distribution to TestPyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
|
||||
publish-pypi:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
needs:
|
||||
[
|
||||
check-version,
|
||||
frontend-checks,
|
||||
frontend-tests,
|
||||
python-checks,
|
||||
python-tests,
|
||||
build,
|
||||
]
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/invokeai
|
||||
steps:
|
||||
- name: download distribution from build job
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist/
|
||||
|
||||
- name: publish distribution to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
24
.github/workflows/style-checks.yml
vendored
Normal file
24
.github/workflows/style-checks.yml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: style checks
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches: main
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies with pip
|
||||
run: |
|
||||
pip install ruff
|
||||
|
||||
- run: ruff check --output-format=github .
|
||||
- run: ruff format --check .
|
||||
129
.github/workflows/test-invoke-pip.yml
vendored
Normal file
129
.github/workflows/test-invoke-pip.yml
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
name: Test invoke.py pip
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
matrix:
|
||||
if: github.event.pull_request.draft == false
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
# - '3.9'
|
||||
- '3.10'
|
||||
pytorch:
|
||||
- linux-cuda-11_7
|
||||
- linux-rocm-5_2
|
||||
- linux-cpu
|
||||
- macos-default
|
||||
- windows-cpu
|
||||
include:
|
||||
- pytorch: linux-cuda-11_7
|
||||
os: ubuntu-22.04
|
||||
github-env: $GITHUB_ENV
|
||||
- pytorch: linux-rocm-5_2
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
|
||||
github-env: $GITHUB_ENV
|
||||
- pytorch: linux-cpu
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- pytorch: macos-default
|
||||
os: macOS-12
|
||||
github-env: $GITHUB_ENV
|
||||
- pytorch: windows-cpu
|
||||
os: windows-2022
|
||||
github-env: $env:GITHUB_ENV
|
||||
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
PIP_USE_PEP517: '1'
|
||||
steps:
|
||||
- name: Checkout sources
|
||||
id: checkout-sources
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Check for changed python files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v41
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: set test prompt to main branch validation
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: setup python
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install invokeai
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
env:
|
||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||
run: >
|
||||
pip3 install
|
||||
--editable=".[test]"
|
||||
|
||||
- name: run pytest
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
id: run-pytest
|
||||
run: pytest
|
||||
|
||||
# - name: run invokeai-configure
|
||||
# env:
|
||||
# HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
|
||||
# run: >
|
||||
# invokeai-configure
|
||||
# --yes
|
||||
# --default_only
|
||||
# --full-precision
|
||||
# # can't use fp16 weights without a GPU
|
||||
|
||||
# - name: run invokeai
|
||||
# id: run-invokeai
|
||||
# env:
|
||||
# # Set offline mode to make sure configure preloaded successfully.
|
||||
# HF_HUB_OFFLINE: 1
|
||||
# HF_DATASETS_OFFLINE: 1
|
||||
# TRANSFORMERS_OFFLINE: 1
|
||||
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
# run: >
|
||||
# invokeai
|
||||
# --no-patchmatch
|
||||
# --no-nsfw_checker
|
||||
# --precision=float32
|
||||
# --always_use_cpu
|
||||
# --use_memory_db
|
||||
# --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
|
||||
# --from_file ${{ env.TEST_PROMPTS }}
|
||||
|
||||
# - name: Archive results
|
||||
# env:
|
||||
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
# uses: actions/upload-artifact@v3
|
||||
# with:
|
||||
# name: results
|
||||
# path: ${{ env.INVOKEAI_OUTDIR }}
|
||||
@@ -7,7 +7,7 @@ embeddedLanguageFormatting: auto
|
||||
overrides:
|
||||
- files: '*.md'
|
||||
options:
|
||||
proseWrap: preserve
|
||||
proseWrap: always
|
||||
printWidth: 80
|
||||
parser: markdown
|
||||
cursorOffset: -1
|
||||
|
||||
43
Makefile
43
Makefile
@@ -6,45 +6,33 @@ default: help
|
||||
help:
|
||||
@echo Developer commands:
|
||||
@echo
|
||||
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
|
||||
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
|
||||
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
|
||||
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
||||
@echo "test Run the unit tests."
|
||||
@echo "frontend-install Install the pnpm modules needed for the front end"
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
|
||||
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
|
||||
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
|
||||
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
|
||||
# Runs ruff, fixing any safely-fixable errors and formatting
|
||||
ruff:
|
||||
ruff check . --fix
|
||||
ruff format .
|
||||
ruff check . --fix
|
||||
ruff format .
|
||||
|
||||
# Runs ruff, fixing all errors it can fix and formatting
|
||||
ruff-unsafe:
|
||||
ruff check . --fix --unsafe-fixes
|
||||
ruff format .
|
||||
ruff check . --fix --unsafe-fixes
|
||||
ruff format .
|
||||
|
||||
# Runs mypy, using the config in pyproject.toml
|
||||
mypy:
|
||||
mypy scripts/invokeai-web.py
|
||||
mypy scripts/invokeai-web.py
|
||||
|
||||
# Runs mypy, ignoring the config in pyproject.toml but still ignoring missing (untyped) imports
|
||||
# (many files are ignored by the config, so this is useful for checking all files)
|
||||
mypy-all:
|
||||
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
|
||||
|
||||
# Run the unit tests
|
||||
test:
|
||||
pytest ./tests
|
||||
|
||||
# Install the pnpm modules needed for the front end
|
||||
frontend-install:
|
||||
rm -rf invokeai/frontend/web/node_modules
|
||||
cd invokeai/frontend/web && pnpm install
|
||||
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
|
||||
|
||||
# Build the frontend
|
||||
frontend-build:
|
||||
@@ -54,9 +42,6 @@ frontend-build:
|
||||
frontend-dev:
|
||||
cd invokeai/frontend/web && pnpm dev
|
||||
|
||||
frontend-typegen:
|
||||
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
||||
|
||||
# Installer zip file
|
||||
installer-zip:
|
||||
cd installer && ./create_installer.sh
|
||||
|
||||
142
docs/RELEASE.md
142
docs/RELEASE.md
@@ -1,142 +0,0 @@
|
||||
# Release Process
|
||||
|
||||
The app is published in twice, in different build formats.
|
||||
|
||||
- A [PyPI] distribution. This includes both a source distribution and built distribution (a wheel). Users install with `pip install invokeai`. The updater uses this build.
|
||||
- An installer on the [InvokeAI Releases Page]. This is a zip file with install scripts and a wheel. This is only used for new installs.
|
||||
|
||||
## General Prep
|
||||
|
||||
Make a developer call-out for PRs to merge. Merge and test things out.
|
||||
|
||||
While the release workflow does not include end-to-end tests, it does pause before publishing so you can download and test the final build.
|
||||
|
||||
## Release Workflow
|
||||
|
||||
The `release.yml` workflow runs a number of jobs to handle code checks, tests, build and publish on PyPI.
|
||||
|
||||
It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if you've prepped a release branch like `release/v3.5.0` or are releasing from `main` - it works the same.
|
||||
|
||||
> Because commits are reference-counted, it is safe to create a release branch, tag it, let the workflow run, then delete the branch. So long as the tag exists, that commit will exist.
|
||||
|
||||
### Triggering the Workflow
|
||||
|
||||
Run `make tag-release` to tag the current commit and kick off the workflow.
|
||||
|
||||
The release may also be dispatched [manually].
|
||||
|
||||
### Workflow Jobs and Process
|
||||
|
||||
The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
|
||||
|
||||
The publish jobs require manual approval and are only run if the other jobs succeed.
|
||||
|
||||
#### `check-version` Job
|
||||
|
||||
This job checks that the git ref matches the app version. It matches the ref against the `__version__` variable in `invokeai/version/invokeai_version.py`.
|
||||
|
||||
When the workflow is triggered by tag push, the ref is the tag. If the workflow is run manually, the ref is the target selected from the **Use workflow from** dropdown.
|
||||
|
||||
This job uses [samuelcolvin/check-python-version].
|
||||
|
||||
> Any valid [version specifier] works, so long as the tag matches the version. The release workflow works exactly the same for `RC`, `post`, `dev`, etc.
|
||||
|
||||
#### Check and Test Jobs
|
||||
|
||||
- **`python-tests`**: runs `pytest` on matrix of platforms
|
||||
- **`python-checks`**: runs `ruff` (format and lint)
|
||||
- **`frontend-tests`**: runs `vitest`
|
||||
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
|
||||
|
||||
> **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
|
||||
|
||||
> **TODO** We should add an end-to-end test job that generates an image.
|
||||
|
||||
#### `build-installer` Job
|
||||
|
||||
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
|
||||
|
||||
- **`dist`**: the python distribution, to be published on PyPI
|
||||
- **`InvokeAI-installer-${VERSION}.zip`**: the installer to be included in the GitHub release
|
||||
|
||||
#### Sanity Check & Smoke Test
|
||||
|
||||
At this point, the release workflow pauses as the remaining publish jobs require approval.
|
||||
|
||||
A maintainer should go to the **Summary** tab of the workflow, download the installer and test it. Ensure the app loads and generates.
|
||||
|
||||
> The same wheel file is bundled in the installer and in the `dist` artifact, which is uploaded to PyPI. You should end up with the exactly the same installation of the `invokeai` package from any of these methods.
|
||||
|
||||
#### PyPI Publish Jobs
|
||||
|
||||
The publish jobs will run if any of the previous jobs fail.
|
||||
|
||||
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
|
||||
|
||||
Both jobs require a maintainer to approve them from the workflow's **Summary** tab.
|
||||
|
||||
- Click the **Review deployments** button
|
||||
- Select the environment (either `testpypi` or `pypi`)
|
||||
- Click **Approve and deploy**
|
||||
|
||||
> **If the version already exists on PyPI, the publish jobs will fail.** PyPI only allows a given version to be published once - you cannot change it. If version published on PyPI has a problem, you'll need to "fail forward" by bumping the app version and publishing a followup release.
|
||||
|
||||
#### `publish-testpypi` Job
|
||||
|
||||
Publishes the distribution on the [Test PyPI] index, using the `testpypi` GitHub environment.
|
||||
|
||||
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release.
|
||||
|
||||
If approved and successful, you could try out the test release like this:
|
||||
|
||||
```sh
|
||||
# Create a new virtual environment
|
||||
python -m venv ~/.test-invokeai-dist --prompt test-invokeai-dist
|
||||
# Install the distribution from Test PyPI
|
||||
pip install --index-url https://test.pypi.org/simple/ invokeai
|
||||
# Run and test the app
|
||||
invokeai-web
|
||||
# Cleanup
|
||||
deactivate
|
||||
rm -rf ~/.test-invokeai-dist
|
||||
```
|
||||
|
||||
#### `publish-pypi` Job
|
||||
|
||||
Publishes the distribution on the production PyPI index, using the `pypi` GitHub environment.
|
||||
|
||||
## Publish the GitHub Release with installer
|
||||
|
||||
Once the release is published to PyPI, it's time to publish the GitHub release.
|
||||
|
||||
1. [Draft a new release] on GitHub, choosing the tag that triggered the release.
|
||||
2. Write the release notes, describing important changes. The **Generate release notes** button automatically inserts the changelog and new contributors, and you can copy/paste the intro from previous releases.
|
||||
3. Upload the zip file created in **`build`** job into the Assets section of the release notes. You can also upload the zip into the body of the release notes, since it can be hard for users to find the Assets section.
|
||||
4. Check the **Set as a pre-release** and **Create a discussion for this release** checkboxes at the bottom of the release page.
|
||||
5. Publish the pre-release.
|
||||
6. Announce the pre-release in Discord.
|
||||
|
||||
> **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
|
||||
|
||||
## Manual Build
|
||||
|
||||
The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
|
||||
|
||||
No checks are run, it just builds.
|
||||
|
||||
## Manual Release
|
||||
|
||||
The `release` workflow can be dispatched manually. You must dispatch the workflow from the right tag, else it will fail the version check.
|
||||
|
||||
This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
|
||||
|
||||
[InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
|
||||
[PyPI]: https://pypi.org/
|
||||
[Draft a new release]: https://github.com/invoke-ai/InvokeAI/releases/new
|
||||
[Test PyPI]: https://test.pypi.org/
|
||||
[version specifier]: https://packaging.python.org/en/latest/specifications/version-specifiers/
|
||||
[ncipollo/release-action]: https://github.com/ncipollo/release-action
|
||||
[GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
|
||||
[trusted publishers]: https://docs.pypi.org/trusted-publishers/
|
||||
[samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version
|
||||
[manually]: #manual-release
|
||||
@@ -32,6 +32,7 @@ model. These are the:
|
||||
Responsible for loading a model from disk
|
||||
into RAM and VRAM and getting it ready for inference.
|
||||
|
||||
|
||||
## Location of the Code
|
||||
|
||||
The four main services can be found in
|
||||
@@ -62,21 +63,23 @@ provides the following fields:
|
||||
|----------------|-----------------|------------------|
|
||||
| `key` | str | Unique identifier for the model |
|
||||
| `name` | str | Name of the model (not unique) |
|
||||
| `model_type` | ModelType | The type of the model |
|
||||
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
|
||||
| `base_model` | BaseModelType | The base model that the model is compatible with |
|
||||
| `model_type` | ModelType | The type of the model |
|
||||
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
|
||||
| `base_model` | BaseModelType | The base model that the model is compatible with |
|
||||
| `path` | str | Location of model on disk |
|
||||
| `hash` | str | Hash of the model |
|
||||
| `original_hash` | str | Hash of the model when it was first installed |
|
||||
| `current_hash` | str | Most recent hash of the model's contents |
|
||||
| `description` | str | Human-readable description of the model (optional) |
|
||||
| `source` | str | Model's source URL or repo id (optional) |
|
||||
|
||||
The `key` is a unique 32-character random ID which was generated at
|
||||
install time. The `hash` field stores a hash of the model's
|
||||
install time. The `original_hash` field stores a hash of the model's
|
||||
contents at install time obtained by sampling several parts of the
|
||||
model's files using the `imohash` library. Over the course of the
|
||||
model's lifetime it may be transformed in various ways, such as
|
||||
changing its precision or converting it from a .safetensors to a
|
||||
diffusers model.
|
||||
diffusers model. When this happens, `original_hash` is unchanged, but
|
||||
`current_hash` is updated to indicate the current contents.
|
||||
|
||||
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
|
||||
are defined in `invokeai.backend.model_manager.config`. They are also
|
||||
@@ -91,6 +94,7 @@ The `path` field can be absolute or relative. If relative, it is taken
|
||||
to be relative to the `models_dir` setting in the user's
|
||||
`invokeai.yaml` file.
|
||||
|
||||
|
||||
### CheckpointConfig
|
||||
|
||||
This adds support for checkpoint configurations, and adds the
|
||||
@@ -170,7 +174,7 @@ store = context.services.model_manager.store
|
||||
or from elsewhere in the code by accessing
|
||||
`ApiDependencies.invoker.services.model_manager.store`.
|
||||
|
||||
### Creating a `ModelRecordService`
|
||||
### Creating a `ModelRecordService`
|
||||
|
||||
To create a new `ModelRecordService` database or open an existing one,
|
||||
you can directly create either a `ModelRecordServiceSQL` or a
|
||||
@@ -213,27 +217,27 @@ for use in the InvokeAI web server. Its signature is:
|
||||
```
|
||||
def open(
|
||||
cls,
|
||||
config: InvokeAIAppConfig,
|
||||
conn: Optional[sqlite3.Connection] = None,
|
||||
lock: Optional[threading.Lock] = None
|
||||
config: InvokeAIAppConfig,
|
||||
conn: Optional[sqlite3.Connection] = None,
|
||||
lock: Optional[threading.Lock] = None
|
||||
) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
|
||||
```
|
||||
|
||||
The way it works is as follows:
|
||||
|
||||
1. Retrieve the value of the `model_config_db` option from the user's
|
||||
`invokeai.yaml` config file.
|
||||
`invokeai.yaml` config file.
|
||||
2. If `model_config_db` is `auto` (the default), then:
|
||||
* Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
|
||||
opened on the passed connection and lock.
|
||||
* Open up a new connection to `databases/invokeai.db` if `conn`
|
||||
- Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
|
||||
opened on the passed connection and lock.
|
||||
- Open up a new connection to `databases/invokeai.db` if `conn`
|
||||
and/or `lock` are missing (see note below).
|
||||
3. If `model_config_db` is a Path, then use `from_db_file`
|
||||
to return the appropriate type of ModelRecordService.
|
||||
4. If `model_config_db` is None, then retrieve the legacy
|
||||
`conf_path` option from `invokeai.yaml` and use the Path
|
||||
indicated there. This will default to `configs/models.yaml`.
|
||||
|
||||
|
||||
So a typical startup pattern would be:
|
||||
|
||||
```
|
||||
@@ -251,7 +255,7 @@ store = ModelRecordServiceBase.open(config, db_conn, lock)
|
||||
|
||||
Configurations can be retrieved in several ways.
|
||||
|
||||
#### get_model(key) -> AnyModelConfig
|
||||
#### get_model(key) -> AnyModelConfig:
|
||||
|
||||
The basic functionality is to call the record store object's
|
||||
`get_model()` method with the desired model's unique key. It returns
|
||||
@@ -268,28 +272,28 @@ print(model_conf.path)
|
||||
If the key is unrecognized, this call raises an
|
||||
`UnknownModelException`.
|
||||
|
||||
#### exists(key) -> AnyModelConfig
|
||||
#### exists(key) -> AnyModelConfig:
|
||||
|
||||
Returns True if a model with the given key exists in the databsae.
|
||||
|
||||
#### search_by_path(path) -> AnyModelConfig
|
||||
#### search_by_path(path) -> AnyModelConfig:
|
||||
|
||||
Returns the configuration of the model whose path is `path`. The path
|
||||
is matched using a simple string comparison and won't correctly match
|
||||
models referred to by different paths (e.g. using symbolic links).
|
||||
|
||||
#### search_by_name(name, base, type) -> List[AnyModelConfig]
|
||||
#### search_by_name(name, base, type) -> List[AnyModelConfig]:
|
||||
|
||||
This method searches for models that match some combination of `name`,
|
||||
`BaseType` and `ModelType`. Calling without any arguments will return
|
||||
all the models in the database.
|
||||
|
||||
#### all_models() -> List[AnyModelConfig]
|
||||
#### all_models() -> List[AnyModelConfig]:
|
||||
|
||||
Return all the model configs in the database. Exactly equivalent to
|
||||
calling `search_by_name()` with no arguments.
|
||||
|
||||
#### search_by_tag(tags) -> List[AnyModelConfig]
|
||||
#### search_by_tag(tags) -> List[AnyModelConfig]:
|
||||
|
||||
`tags` is a list of strings. This method returns a list of model
|
||||
configs that contain all of the given tags. Examples:
|
||||
@@ -308,11 +312,11 @@ commercializable_models = [x for x in store.all_models() \
|
||||
if x.license.contains('allowCommercialUse=Sell')]
|
||||
```
|
||||
|
||||
#### version() -> str
|
||||
#### version() -> str:
|
||||
|
||||
Returns the version of the database, currently at `3.2`
|
||||
|
||||
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase
|
||||
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase:
|
||||
|
||||
This method exists to ease the transition from the previous version of
|
||||
the model manager, in which `get_model()` took the three arguments
|
||||
@@ -333,7 +337,7 @@ model and pass its key to `get_model()`.
|
||||
Several methods allow you to create and update stored model config
|
||||
records.
|
||||
|
||||
#### add_model(key, config) -> AnyModelConfig
|
||||
#### add_model(key, config) -> AnyModelConfig:
|
||||
|
||||
Given a key and a configuration, this will add the model's
|
||||
configuration record to the database. `config` can either be a subclass of
|
||||
@@ -348,7 +352,7 @@ model with the same key is already in the database, or an
|
||||
`InvalidModelConfigException` if a dict was passed and Pydantic
|
||||
experienced a parse or validation error.
|
||||
|
||||
### update_model(key, config) -> AnyModelConfig
|
||||
### update_model(key, config) -> AnyModelConfig:
|
||||
|
||||
Given a key and a configuration, this will update the model
|
||||
configuration record in the database. `config` can be either a
|
||||
@@ -366,31 +370,31 @@ The `ModelInstallService` class implements the
|
||||
shop for all your model install needs. It provides the following
|
||||
functionality:
|
||||
|
||||
* Registering a model config record for a model already located on the
|
||||
- Registering a model config record for a model already located on the
|
||||
local filesystem, without moving it or changing its path.
|
||||
|
||||
* Installing a model alreadiy located on the local filesystem, by
|
||||
- Installing a model alreadiy located on the local filesystem, by
|
||||
moving it into the InvokeAI root directory under the
|
||||
`models` folder (or wherever config parameter `models_dir`
|
||||
specifies).
|
||||
|
||||
* Probing of models to determine their type, base type and other key
|
||||
|
||||
- Probing of models to determine their type, base type and other key
|
||||
information.
|
||||
|
||||
* Interface with the InvokeAI event bus to provide status updates on
|
||||
- Interface with the InvokeAI event bus to provide status updates on
|
||||
the download, installation and registration process.
|
||||
|
||||
* Downloading a model from an arbitrary URL and installing it in
|
||||
- Downloading a model from an arbitrary URL and installing it in
|
||||
`models_dir`.
|
||||
|
||||
* Special handling for Civitai model URLs which allow the user to
|
||||
- Special handling for Civitai model URLs which allow the user to
|
||||
paste in a model page's URL or download link
|
||||
|
||||
* Special handling for HuggingFace repo_ids to recursively download
|
||||
- Special handling for HuggingFace repo_ids to recursively download
|
||||
the contents of the repository, paying attention to alternative
|
||||
variants such as fp16.
|
||||
|
||||
* Saving tags and other metadata about the model into the invokeai database
|
||||
- Saving tags and other metadata about the model into the invokeai database
|
||||
when fetching from a repo that provides that type of information,
|
||||
(currently only Civitai and HuggingFace).
|
||||
|
||||
@@ -423,8 +427,8 @@ queue.start()
|
||||
|
||||
installer = ModelInstallService(app_config=config,
|
||||
record_store=record_store,
|
||||
download_queue=queue
|
||||
)
|
||||
download_queue=queue
|
||||
)
|
||||
installer.start()
|
||||
```
|
||||
|
||||
@@ -439,6 +443,7 @@ required parameters:
|
||||
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|
||||
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
|
||||
|
||||
|
||||
Once initialized, the installer will provide the following methods:
|
||||
|
||||
#### install_job = installer.heuristic_import(source, [config], [access_token])
|
||||
@@ -452,15 +457,15 @@ The `source` is a string that can be any of these forms
|
||||
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
|
||||
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
|
||||
3. A HuggingFace repo_id with any of the following formats:
|
||||
* `model/name` -- entire model
|
||||
* `model/name:fp32` -- entire model, using the fp32 variant
|
||||
* `model/name:fp16:vae` -- vae submodel, using the fp16 variant
|
||||
* `model/name::vae` -- vae submodel, using default precision
|
||||
* `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
|
||||
* `model/name::path/to/model.safetensors` -- an individual model file, default variant
|
||||
- `model/name` -- entire model
|
||||
- `model/name:fp32` -- entire model, using the fp32 variant
|
||||
- `model/name:fp16:vae` -- vae submodel, using the fp16 variant
|
||||
- `model/name::vae` -- vae submodel, using default precision
|
||||
- `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
|
||||
- `model/name::path/to/model.safetensors` -- an individual model file, default variant
|
||||
|
||||
Note that by specifying a relative path to the top of the HuggingFace
|
||||
repo, you can download and install arbitrary models files.
|
||||
repo, you can download and install arbitrary models files.
|
||||
|
||||
The variant, if not provided, will be automatically filled in with
|
||||
`fp32` if the user has requested full precision, and `fp16`
|
||||
@@ -486,9 +491,9 @@ following illustrates basic usage:
|
||||
|
||||
```
|
||||
from invokeai.app.services.model_install import (
|
||||
LocalModelSource,
|
||||
HFModelSource,
|
||||
URLModelSource,
|
||||
LocalModelSource,
|
||||
HFModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
|
||||
source1 = LocalModelSource(path='/opt/models/sushi.safetensors') # a local safetensors file
|
||||
@@ -508,13 +513,13 @@ for source in [source1, source2, source3, source4, source5, source6, source7]:
|
||||
source2job = installer.wait_for_installs(timeout=120)
|
||||
for source in sources:
|
||||
job = source2job[source]
|
||||
if job.complete:
|
||||
model_config = job.config_out
|
||||
model_key = model_config.key
|
||||
print(f"{source} installed as {model_key}")
|
||||
elif job.errored:
|
||||
print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
|
||||
|
||||
if job.complete:
|
||||
model_config = job.config_out
|
||||
model_key = model_config.key
|
||||
print(f"{source} installed as {model_key}")
|
||||
elif job.errored:
|
||||
print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
|
||||
|
||||
```
|
||||
|
||||
As shown here, the `import_model()` method accepts a variety of
|
||||
@@ -523,7 +528,7 @@ HuggingFace repo_ids with and without a subfolder designation,
|
||||
Civitai model URLs and arbitrary URLs that point to checkpoint files
|
||||
(but not to folders).
|
||||
|
||||
Each call to `import_model()` return a `ModelInstallJob` job,
|
||||
Each call to `import_model()` return a `ModelInstallJob` job,
|
||||
an object which tracks the progress of the install.
|
||||
|
||||
If a remote model is requested, the model's files are downloaded in
|
||||
@@ -550,7 +555,7 @@ The full list of arguments to `import_model()` is as follows:
|
||||
| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
|
||||
|
||||
The next few sections describe the various types of ModelSource that
|
||||
can be passed to `import_model()`.
|
||||
can be passed to `import_model()`.
|
||||
|
||||
`config` can be used to override all or a portion of the configuration
|
||||
attributes returned by the model prober. See the section below for
|
||||
@@ -561,6 +566,7 @@ details.
|
||||
This is used for a model that is located on a locally-accessible Posix
|
||||
filesystem, such as a local disk or networked fileshare.
|
||||
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `path` | str | Path | None | Path to the model file or directory |
|
||||
@@ -619,6 +625,7 @@ HuggingFace has the most complicated `ModelSource` structure:
|
||||
| `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
|
||||
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
|
||||
|
||||
|
||||
The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.
|
||||
|
||||
The `variant` is one of the various diffusers formats that HuggingFace
|
||||
@@ -654,6 +661,7 @@ in. To download these files, you must provide an
|
||||
`HfFolder.get_token()` will be called to fill it in with the cached
|
||||
one.
|
||||
|
||||
|
||||
#### Monitoring the install job process
|
||||
|
||||
When you create an install job with `import_model()`, it launches the
|
||||
@@ -667,13 +675,14 @@ The `ModelInstallJob` class has the following structure:
|
||||
| `id` | `int` | Integer ID for this job |
|
||||
| `status` | `InstallStatus` | An enum of [`waiting`, `downloading`, `running`, `completed`, `error` and `cancelled`]|
|
||||
| `config_in` | `dict` | Overriding configuration values provided by the caller |
|
||||
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
|
||||
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
|
||||
| `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
|
||||
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
|
||||
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
|
||||
| `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
|
||||
| `local_path` | `Path` | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
|
||||
| `error_type` | `str` | Name of the exception that led to an error status |
|
||||
| `error` | `str` | Traceback of the error |
|
||||
|
||||
|
||||
If the `event_bus` argument was provided, events will also be
|
||||
broadcast to the InvokeAI event bus. The events will appear on the bus
|
||||
as an event of type `EventServiceBase.model_event`, a timestamp and
|
||||
@@ -693,13 +702,14 @@ following keys:
|
||||
| `total_bytes` | int | Total size of all the files that make up the model |
|
||||
| `parts` | List[Dict]| Information on the progress of the individual files that make up the model |
|
||||
|
||||
|
||||
The parts is a list of dictionaries that give information on each of
|
||||
the components pieces of the download. The dictionary's keys are
|
||||
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
|
||||
the like-named keys in the main event.
|
||||
|
||||
Note that downloading events will not be issued for local models, and
|
||||
that downloading events occur _before_ the running event.
|
||||
that downloading events occur *before* the running event.
|
||||
|
||||
##### `model_install_running`
|
||||
|
||||
@@ -742,13 +752,14 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`
|
||||
and `cancelled`, as well as `in_terminal_state`. The last will return
|
||||
True if the job is in the complete, errored or cancelled states.
|
||||
|
||||
|
||||
#### Model configuration and probing
|
||||
|
||||
The install service uses the `invokeai.backend.model_manager.probe`
|
||||
module during import to determine the model's type, base type, and
|
||||
other configuration parameters. Among other things, it assigns a
|
||||
default name and description for the model based on probed
|
||||
fields.
|
||||
fields.
|
||||
|
||||
When downloading remote models is implemented, additional
|
||||
configuration information, such as list of trigger terms, will be
|
||||
@@ -763,11 +774,11 @@ attributes. Here is an example of setting the
|
||||
```
|
||||
install_job = installer.import_model(
|
||||
source=HFModelSource(repo_id='stabilityai/stable-diffusion-2-1',variant='fp32'),
|
||||
config=dict(
|
||||
prediction_type=SchedulerPredictionType('v_prediction')
|
||||
name='stable diffusion 2 base model',
|
||||
)
|
||||
)
|
||||
config=dict(
|
||||
prediction_type=SchedulerPredictionType('v_prediction')
|
||||
name='stable diffusion 2 base model',
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### Other installer methods
|
||||
@@ -851,6 +862,7 @@ This method is similar to `unregister()`, but also unconditionally
|
||||
deletes the corresponding model weights file(s), regardless of whether
|
||||
they are inside or outside the InvokeAI models hierarchy.
|
||||
|
||||
|
||||
#### path = installer.download_and_cache(remote_source, [access_token], [timeout])
|
||||
|
||||
This utility routine will download the model file located at source,
|
||||
@@ -941,7 +953,7 @@ following fields:
|
||||
|
||||
When you create a job, you can assign it a `priority`. If multiple
|
||||
jobs are queued, the job with the lowest priority runs first. (Don't
|
||||
blame me! The Unix developers came up with this convention.)
|
||||
blame me! The Unix developers came up with this convention.)
|
||||
|
||||
Every job has a `source` and a `destination`. `source` is a string in
|
||||
the base class, but subclassses redefine it more specifically.
|
||||
@@ -962,7 +974,7 @@ is in its lifecycle. Values are defined in the string enum
|
||||
`DownloadJobStatus`, a symbol available from
|
||||
`invokeai.app.services.download_manager`. Possible values are:
|
||||
|
||||
| **Value** | **String Value** | **Description** |
|
||||
| **Value** | **String Value** | ** Description ** |
|
||||
|--------------|---------------------|-------------------|
|
||||
| `IDLE` | idle | Job created, but not submitted to the queue |
|
||||
| `ENQUEUED` | enqueued | Job is patiently waiting on the queue |
|
||||
@@ -979,7 +991,7 @@ debugging and performance testing.
|
||||
|
||||
In case of an error, the Exception that caused the error will be
|
||||
placed in the `error` field, and the job's status will be set to
|
||||
`DownloadJobStatus.ERROR`.
|
||||
`DownloadJobStatus.ERROR`.
|
||||
|
||||
After an error occurs, any partially downloaded files will be deleted
|
||||
from disk, unless `preserve_partial_downloads` was set to True at job
|
||||
@@ -1028,11 +1040,11 @@ While a job is being downloaded, the queue will emit events at
|
||||
periodic intervals. A typical series of events during a successful
|
||||
download session will look like this:
|
||||
|
||||
* enqueued
|
||||
* running
|
||||
* running
|
||||
* running
|
||||
* completed
|
||||
- enqueued
|
||||
- running
|
||||
- running
|
||||
- running
|
||||
- completed
|
||||
|
||||
There will be a single enqueued event, followed by one or more running
|
||||
events, and finally one `completed`, `error` or `cancelled`
|
||||
@@ -1041,12 +1053,12 @@ events.
|
||||
It is possible for a caller to pause download temporarily, in which
|
||||
case the events may look something like this:
|
||||
|
||||
* enqueued
|
||||
* running
|
||||
* running
|
||||
* paused
|
||||
* running
|
||||
* completed
|
||||
- enqueued
|
||||
- running
|
||||
- running
|
||||
- paused
|
||||
- running
|
||||
- completed
|
||||
|
||||
The download queue logs when downloads start and end (unless `quiet`
|
||||
is set to True at initialization time) but doesn't log any progress
|
||||
@@ -1108,11 +1120,11 @@ A typical initialization sequence will look like:
|
||||
from invokeai.app.services.download_manager import DownloadQueueService
|
||||
|
||||
def log_download_event(job: DownloadJobBase):
|
||||
logger.info(f'job={job.id}: status={job.status}')
|
||||
logger.info(f'job={job.id}: status={job.status}')
|
||||
|
||||
queue = DownloadQueueService(
|
||||
event_handlers=[log_download_event]
|
||||
)
|
||||
event_handlers=[log_download_event]
|
||||
)
|
||||
```
|
||||
|
||||
Event handlers can be provided to the queue at initialization time as
|
||||
@@ -1143,9 +1155,9 @@ To use the former method, follow this example:
|
||||
```
|
||||
job = DownloadJobRemoteSource(
|
||||
source='http://www.civitai.com/models/13456',
|
||||
destination='/tmp/models/',
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
)
|
||||
destination='/tmp/models/',
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
)
|
||||
queue.submit_download_job(job, start=True)
|
||||
```
|
||||
|
||||
@@ -1160,13 +1172,13 @@ To have the queue create the job for you, follow this example instead:
|
||||
```
|
||||
job = queue.create_download_job(
|
||||
source='http://www.civitai.com/models/13456',
|
||||
destdir='/tmp/models/',
|
||||
filename='my_model.safetensors',
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
start=True,
|
||||
)
|
||||
destdir='/tmp/models/',
|
||||
filename='my_model.safetensors',
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
start=True,
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
The `filename` argument forces the downloader to use the specified
|
||||
name for the file rather than the name provided by the remote source,
|
||||
and is equivalent to manually specifying a destination of
|
||||
@@ -1175,6 +1187,7 @@ and is equivalent to manually specifying a destination of
|
||||
Here is the full list of arguments that can be provided to
|
||||
`create_download_job()`:
|
||||
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source |
|
||||
@@ -1187,7 +1200,7 @@ Here is the full list of arguments that can be provided to
|
||||
|
||||
Internally, `create_download_job()` has a little bit of internal logic
|
||||
that looks at the type of the source and selects the right subclass of
|
||||
`DownloadJobBase` to create and enqueue.
|
||||
`DownloadJobBase` to create and enqueue.
|
||||
|
||||
**TODO**: move this logic into its own method for overriding in
|
||||
subclasses.
|
||||
@@ -1262,7 +1275,7 @@ for getting the model to run. For example "author" is metadata, while
|
||||
"type", "base" and "format" are not. The latter fields are part of the
|
||||
model's config, as defined in `invokeai.backend.model_manager.config`.
|
||||
|
||||
### Example Usage
|
||||
### Example Usage:
|
||||
|
||||
```
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
@@ -1315,6 +1328,7 @@ This is the common base class for metadata:
|
||||
| `author` | str | Model's author |
|
||||
| `tags` | Set[str] | Model tags |
|
||||
|
||||
|
||||
Note that the model config record also has a `name` field. It is
|
||||
intended that the config record version be locally customizable, while
|
||||
the metadata version is read-only. However, enforcing this is expected
|
||||
@@ -1334,6 +1348,7 @@ This descends from `ModelMetadataBase` and adds the following fields:
|
||||
| `last_modified`| datetime | Date of last commit of this model to the repo |
|
||||
| `files` | List[Path] | List of the files in the model repo |
|
||||
|
||||
|
||||
#### `CivitaiMetadata`
|
||||
|
||||
This descends from `ModelMetadataBase` and adds the following fields:
|
||||
@@ -1400,6 +1415,7 @@ testing suite to avoid hitting the internet.
|
||||
The HuggingFace and Civitai fetcher subclasses add additional
|
||||
repo-specific fetching methods:
|
||||
|
||||
|
||||
#### HuggingFaceMetadataFetch
|
||||
|
||||
This overrides its base class `from_json()` method to return a
|
||||
@@ -1418,12 +1434,13 @@ retrieves its metadata. Functionally equivalent to `from_id()`, the
|
||||
only difference is that it returna a `CivitaiMetadata` object rather
|
||||
than an `AnyModelRepoMetadata`.
|
||||
|
||||
|
||||
### Metadata Storage
|
||||
|
||||
The `ModelMetadataStore` provides a simple facility to store model
|
||||
metadata in the `invokeai.db` database. The data is stored as a JSON
|
||||
blob, with a few common fields (`name`, `author`, `tags`) broken out
|
||||
to be searchable.
|
||||
to be searchable.
|
||||
|
||||
When a metadata object is saved to the database, it is identified
|
||||
using the model key, _and this key must correspond to an existing
|
||||
@@ -1518,16 +1535,16 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||
)
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||
)
|
||||
loader = ModelLoadService(
|
||||
app_config=config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
registry=ModelLoaderRegistry
|
||||
app_config=config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
registry=ModelLoaderRegistry
|
||||
)
|
||||
```
|
||||
|
||||
@@ -1550,6 +1567,7 @@ The returned `LoadedModel` object contains a copy of the configuration
|
||||
record returned by the model record `get_model()` method, as well as
|
||||
the in-memory loaded model:
|
||||
|
||||
|
||||
| **Attribute Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
|
||||
@@ -1563,6 +1581,7 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
|
||||
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
|
||||
models. The others are obvious.
|
||||
|
||||
|
||||
`LoadedModel` acts as a context manager. The context loads the model
|
||||
into the execution device (e.g. VRAM on CUDA systems), locks the model
|
||||
in the execution device for the duration of the context, and returns
|
||||
@@ -1571,14 +1590,14 @@ the model. Use it like this:
|
||||
```
|
||||
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
with model_info as vae:
|
||||
image = vae.decode(latents)[0]
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
`get_model_by_key()` may raise any of the following exceptions:
|
||||
|
||||
* `UnknownModelException` -- key not in database
|
||||
* `ModelNotFoundException` -- key in database but model not found at path
|
||||
* `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||
- `UnknownModelException` -- key not in database
|
||||
- `ModelNotFoundException` -- key in database but model not found at path
|
||||
- `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||
|
||||
### Emitting model loading events
|
||||
|
||||
@@ -1590,15 +1609,15 @@ following payload:
|
||||
|
||||
```
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel_type=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
precision=str(model_info.precision),
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel_type=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
precision=str(model_info.precision),
|
||||
)
|
||||
```
|
||||
|
||||
@@ -1705,7 +1724,6 @@ object, or in `context.services.model_manager` from within an
|
||||
invocation.
|
||||
|
||||
In the examples below, we have retrieved the manager using:
|
||||
|
||||
```
|
||||
mm = ApiDependencies.invoker.services.model_manager
|
||||
```
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
# Invocation API
|
||||
|
||||
Each invocation's `invoke` method is provided a single arg - the Invocation
|
||||
Context.
|
||||
|
||||
This object provides access to various methods, used to interact with the
|
||||
application. Loading and saving images, logging messages, etc.
|
||||
|
||||
!!! warning ""
|
||||
|
||||
This API may shift slightly until the release of v4.0.0 as we work through a few final updates to the Model Manager.
|
||||
|
||||
```py
|
||||
class MyInvocation(BaseInvocation):
|
||||
...
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image_pil = context.images.get_pil(image_name)
|
||||
# Do something to the image
|
||||
image_dto = context.images.save(image_pil)
|
||||
# Log a message
|
||||
context.logger.info(f"Did something cool, image saved!")
|
||||
...
|
||||
```
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
::: invokeai.app.services.shared.invocation_context.InvocationContext
|
||||
options:
|
||||
members: false
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.ImagesInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.TensorsInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.ConditioningInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.ModelsInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.LoggerInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.ConfigInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.UtilInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.BoardsInterface
|
||||
<!-- prettier-ignore-end -->
|
||||
@@ -1,148 +0,0 @@
|
||||
# Invoke v4.0.0 Nodes API Migration guide
|
||||
|
||||
Invoke v4.0.0 is versioned as such due to breaking changes to the API utilized
|
||||
by nodes, both core and custom.
|
||||
|
||||
## Motivation
|
||||
|
||||
Prior to v4.0.0, the `invokeai` python package has not be set up to be utilized
|
||||
as a library. That is to say, it didn't have any explicitly public API, and node
|
||||
authors had to work with the unstable internal application API.
|
||||
|
||||
v4.0.0 introduces a stable public API for nodes.
|
||||
|
||||
## Changes
|
||||
|
||||
There are two node-author-facing changes:
|
||||
|
||||
1. Import Paths
|
||||
1. Invocation Context API
|
||||
|
||||
### Import Paths
|
||||
|
||||
All public objects are now exported from `invokeai.invocation_api`:
|
||||
|
||||
```py
|
||||
# Old
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
|
||||
# New
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
ImageField,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
```
|
||||
|
||||
It's possible that we've missed some classes you need in your node. Please let
|
||||
us know if that's the case.
|
||||
|
||||
### Invocation Context API
|
||||
|
||||
Most nodes utilize the Invocation Context, an object that is passed to the
|
||||
`invoke` that provides access to data and services a node may need.
|
||||
|
||||
Until now, that object and the services it exposed were internal. Exposing them
|
||||
to nodes means that changes to our internal implementation could break nodes.
|
||||
The methods on the services are also often fairly complicated and allowed nodes
|
||||
to footgun.
|
||||
|
||||
In v4.0.0, this object has been refactored to be much simpler.
|
||||
|
||||
See [INVOCATION_API](./INVOCATION_API.md) for full details of the API.
|
||||
|
||||
!!! warning ""
|
||||
|
||||
This API may shift slightly until the release of v4.0.0 as we work through a few final updates to the Model Manager.
|
||||
|
||||
#### Improved Service Methods
|
||||
|
||||
The biggest offender was the image save method:
|
||||
|
||||
```py
|
||||
# Old
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
# New
|
||||
image_dto = context.images.save(image=image)
|
||||
```
|
||||
|
||||
Other methods are simplified, or enhanced with additional functionality:
|
||||
|
||||
```py
|
||||
# Old
|
||||
image = context.services.images.get_pil_image(image_name)
|
||||
|
||||
# New
|
||||
image = context.images.get_pil(image_name)
|
||||
image_cmyk = context.images.get_pil(image_name, "CMYK")
|
||||
```
|
||||
|
||||
We also had some typing issues around tensors:
|
||||
|
||||
```py
|
||||
# Old
|
||||
# `latents` typed as `torch.Tensor`, but could be `ConditioningFieldData`
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
# `data` typed as `torch.Tenssor,` but could be `ConditioningFieldData`
|
||||
context.services.latents.save(latents_name, data)
|
||||
|
||||
# New - separate methods for tensors and conditioning data w/ correct typing
|
||||
# Also, the service generates the names
|
||||
tensor_name = context.tensors.save(tensor)
|
||||
tensor = context.tensors.load(tensor_name)
|
||||
# For conditioning
|
||||
cond_name = context.conditioning.save(cond_data)
|
||||
cond_data = context.conditioning.load(cond_name)
|
||||
```
|
||||
|
||||
#### Output Construction
|
||||
|
||||
Core Outputs have builder functions right on them - no need to manually
|
||||
construct these objects, or use an extra utility:
|
||||
|
||||
```py
|
||||
# Old
|
||||
image_output = ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
latents_output = build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
noise_output = NoiseOutput(
|
||||
noise=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
cond_output = ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
# New
|
||||
image_output = ImageOutput.build(image_dto)
|
||||
latents_output = LatentsOutput.build(latents_name=name, latents=noise, seed=self.seed)
|
||||
noise_output = NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)
|
||||
cond_output = ConditioningOutput.build(conditioning_name)
|
||||
```
|
||||
|
||||
You can still create the objects using constructors if you want, but we suggest
|
||||
using the builder methods.
|
||||
@@ -32,7 +32,6 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
|
||||
+ [Image Picker](#image-picker)
|
||||
+ [Image Resize Plus](#image-resize-plus)
|
||||
+ [Latent Upscale](#latent-upscale)
|
||||
+ [Load Video Frame](#load-video-frame)
|
||||
+ [Make 3D](#make-3d)
|
||||
+ [Mask Operations](#mask-operations)
|
||||
@@ -291,13 +290,6 @@ View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-resize-plus-node/master/.readme/node.png" width="500" />
|
||||
|
||||
|
||||
--------------------------------
|
||||
### Latent Upscale
|
||||
|
||||
**Description:** This node uses a small (~2.4mb) model to upscale the latents used in a Stable Diffusion 1.5 or Stable Diffusion XL image generation, rather than the typical interpolation method, avoiding the traditional downsides of the latent upscale technique.
|
||||
|
||||
**Node Link:** [https://github.com/gogurtenjoyer/latent-upscale](https://github.com/gogurtenjoyer/latent-upscale)
|
||||
|
||||
--------------------------------
|
||||
### Load Video Frame
|
||||
|
||||
@@ -354,21 +346,12 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
|
||||
|
||||
**Description:** A set of nodes for Metadata. Collect Metadata from within an `iterate` node & extract metadata from an image.
|
||||
|
||||
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node
|
||||
- `Metadata From Image` - Provides Metadata from an image
|
||||
- `Metadata To String` - Extracts a String value of a label from metadata
|
||||
- `Metadata To Integer` - Extracts an Integer value of a label from metadata
|
||||
- `Metadata To Float` - Extracts a Float value of a label from metadata
|
||||
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata
|
||||
- `Metadata To Bool` - Extracts Bool types from metadata
|
||||
- `Metadata To Model` - Extracts model types from metadata
|
||||
- `Metadata To SDXL Model` - Extracts SDXL model types from metadata
|
||||
- `Metadata To LoRAs` - Extracts Loras from metadata.
|
||||
- `Metadata To SDXL LoRAs` - Extracts SDXL Loras from metadata
|
||||
- `Metadata To ControlNets` - Extracts ControNets from metadata
|
||||
- `Metadata To IP-Adapters` - Extracts IP-Adapters from metadata
|
||||
- `Metadata To T2I-Adapters` - Extracts T2I-Adapters from metadata
|
||||
- `Denoise Latents + Metadata` - This is an inherited version of the existing `Denoise Latents` node but with a metadata input and output.
|
||||
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node.
|
||||
- `Metadata From Image` - Provides Metadata from an image.
|
||||
- `Metadata To String` - Extracts a String value of a label from metadata.
|
||||
- `Metadata To Integer` - Extracts an Integer value of a label from metadata.
|
||||
- `Metadata To Float` - Extracts a Float value of a label from metadata.
|
||||
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata.
|
||||
|
||||
**Node Link:** https://github.com/skunkworxdark/metadata-linked-nodes
|
||||
|
||||
|
||||
@@ -19,8 +19,6 @@ their descriptions.
|
||||
| Conditioning Primitive | A conditioning tensor primitive value |
|
||||
| Content Shuffle Processor | Applies content shuffle processing to image |
|
||||
| ControlNet | Collects ControlNet info to pass to other nodes |
|
||||
| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. |
|
||||
| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
|
||||
| Denoise Latents | Denoises noisy latents to decodable images |
|
||||
| Divide Integers | Divides two numbers |
|
||||
| Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |
|
||||
|
||||
5
docs/requirements-mkdocs.txt
Normal file
5
docs/requirements-mkdocs.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
mkdocs
|
||||
mkdocs-material>=8, <9
|
||||
mkdocs-git-revision-date-localized-plugin
|
||||
mkdocs-redirects==1.2.0
|
||||
|
||||
5
docs/stylesheets/extra.css
Normal file
5
docs/stylesheets/extra.css
Normal file
@@ -0,0 +1,5 @@
|
||||
:root {
|
||||
--md-primary-fg-color: #35A4DB;
|
||||
--md-primary-fg-color--light: #35A4DB;
|
||||
--md-primary-fg-color--dark: #35A4DB;
|
||||
}
|
||||
@@ -2,18 +2,22 @@
|
||||
|
||||
set -e
|
||||
|
||||
BCYAN="\033[1;36m"
|
||||
BYELLOW="\033[1;33m"
|
||||
BGREEN="\033[1;32m"
|
||||
BRED="\033[1;31m"
|
||||
RED="\033[31m"
|
||||
RESET="\033[0m"
|
||||
BCYAN="\e[1;36m"
|
||||
BYELLOW="\e[1;33m"
|
||||
BGREEN="\e[1;32m"
|
||||
BRED="\e[1;31m"
|
||||
RED="\e[31m"
|
||||
RESET="\e[0m"
|
||||
|
||||
function is_bin_in_path {
|
||||
builtin type -P "$1" &>/dev/null
|
||||
}
|
||||
|
||||
function git_show {
|
||||
git show -s --format=oneline --abbrev-commit "$1" | cat
|
||||
}
|
||||
|
||||
if [[ ! -z "${VIRTUAL_ENV}" ]]; then
|
||||
if [[ -v "VIRTUAL_ENV" ]]; then
|
||||
# we can't just call 'deactivate' because this function is not exported
|
||||
# to the environment of this script from the bash process that runs the script
|
||||
echo -e "${BRED}A virtual environment is activated. Please deactivate it before proceeding.${RESET}"
|
||||
@@ -22,63 +26,31 @@ fi
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
echo
|
||||
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
|
||||
echo "The current working directory is $(pwd)"
|
||||
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
|
||||
echo
|
||||
|
||||
# Some machines only have `python3` in PATH, others have `python` - make an alias.
|
||||
# We can use a function to approximate an alias within a non-interactive shell.
|
||||
if ! is_bin_in_path python && is_bin_in_path python3; then
|
||||
function python {
|
||||
python3 "$@"
|
||||
}
|
||||
fi
|
||||
|
||||
VERSION=$(
|
||||
cd ..
|
||||
python3 -c "from invokeai.version import __version__ as version; print(version)"
|
||||
python -c "from invokeai.version import __version__ as version; print(version)"
|
||||
)
|
||||
VERSION="v${VERSION}"
|
||||
|
||||
if [[ ! -z ${CI} ]]; then
|
||||
echo
|
||||
echo -e "${BCYAN}CI environment detected${RESET}"
|
||||
echo
|
||||
else
|
||||
echo
|
||||
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
|
||||
echo "The current working directory is $(pwd)"
|
||||
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
|
||||
echo
|
||||
fi
|
||||
PATCH=""
|
||||
VERSION="v${VERSION}${PATCH}"
|
||||
|
||||
echo -e "${BGREEN}HEAD${RESET}:"
|
||||
git_show HEAD
|
||||
echo
|
||||
|
||||
# ---------------------- FRONTEND ----------------------
|
||||
|
||||
pushd ../invokeai/frontend/web >/dev/null
|
||||
echo "Installing frontend dependencies..."
|
||||
echo
|
||||
pnpm i --frozen-lockfile
|
||||
echo
|
||||
if [[ ! -z ${CI} ]]; then
|
||||
echo "Building frontend without checks..."
|
||||
# In CI, we have already done the frontend checks and can just build
|
||||
pnpm vite build
|
||||
else
|
||||
echo "Running checks and building frontend..."
|
||||
# This runs all the frontend checks and builds
|
||||
pnpm build
|
||||
fi
|
||||
echo
|
||||
popd
|
||||
|
||||
# ---------------------- BACKEND ----------------------
|
||||
|
||||
echo
|
||||
echo "Building wheel..."
|
||||
echo
|
||||
|
||||
# install the 'build' package in the user site packages, if needed
|
||||
# could be improved by using a temporary venv, but it's tiny and harmless
|
||||
if [[ $(python3 -c 'from importlib.util import find_spec; print(find_spec("build") is None)') == "True" ]]; then
|
||||
pip install --user build
|
||||
fi
|
||||
|
||||
rm -rf ../build
|
||||
|
||||
python3 -m build --outdir dist/ ../.
|
||||
|
||||
# ----------------------
|
||||
|
||||
echo
|
||||
@@ -106,28 +78,10 @@ chmod a+x InvokeAI-Installer/install.sh
|
||||
cp install.bat.in InvokeAI-Installer/install.bat
|
||||
cp WinLongPathsEnabled.reg InvokeAI-Installer/
|
||||
|
||||
FILENAME=InvokeAI-installer-$VERSION.zip
|
||||
|
||||
# Zip everything up
|
||||
zip -r ${FILENAME} InvokeAI-Installer
|
||||
zip -r InvokeAI-installer-$VERSION.zip InvokeAI-Installer
|
||||
|
||||
echo
|
||||
echo -e "${BGREEN}Built installer: ./${FILENAME}${RESET}"
|
||||
echo -e "${BGREEN}Built PyPi distribution: ./dist${RESET}"
|
||||
|
||||
# clean up, but only if we are not in a github action
|
||||
if [[ -z ${CI} ]]; then
|
||||
echo
|
||||
echo "Cleaning up intermediate build files..."
|
||||
rm -rf InvokeAI-Installer tmp ../invokeai/frontend/web/dist/
|
||||
fi
|
||||
|
||||
if [[ ! -z ${CI} ]]; then
|
||||
echo
|
||||
echo "Setting GitHub action outputs..."
|
||||
echo "INSTALLER_FILENAME=${FILENAME}" >>$GITHUB_OUTPUT
|
||||
echo "INSTALLER_PATH=installer/${FILENAME}" >>$GITHUB_OUTPUT
|
||||
echo "DIST_PATH=installer/dist/" >>$GITHUB_OUTPUT
|
||||
fi
|
||||
# clean up
|
||||
rm -rf InvokeAI-Installer tmp dist ../invokeai/frontend/web/dist/
|
||||
|
||||
exit 0
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
|
||||
set -e
|
||||
|
||||
BCYAN="\033[1;36m"
|
||||
BYELLOW="\033[1;33m"
|
||||
BGREEN="\033[1;32m"
|
||||
BRED="\033[1;31m"
|
||||
RED="\033[31m"
|
||||
RESET="\033[0m"
|
||||
BCYAN="\e[1;36m"
|
||||
BYELLOW="\e[1;33m"
|
||||
BGREEN="\e[1;32m"
|
||||
BRED="\e[1;31m"
|
||||
RED="\e[31m"
|
||||
RESET="\e[0m"
|
||||
|
||||
function does_tag_exist {
|
||||
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
|
||||
@@ -23,40 +23,49 @@ function git_show {
|
||||
|
||||
VERSION=$(
|
||||
cd ..
|
||||
python3 -c "from invokeai.version import __version__ as version; print(version)"
|
||||
python -c "from invokeai.version import __version__ as version; print(version)"
|
||||
)
|
||||
PATCH=""
|
||||
MAJOR_VERSION=$(echo $VERSION | sed 's/\..*$//')
|
||||
VERSION="v${VERSION}${PATCH}"
|
||||
LATEST_TAG="v${MAJOR_VERSION}-latest"
|
||||
|
||||
if does_tag_exist $VERSION; then
|
||||
echo -e "${BCYAN}${VERSION}${RESET} already exists:"
|
||||
git_show_ref tags/$VERSION
|
||||
echo
|
||||
fi
|
||||
if does_tag_exist $LATEST_TAG; then
|
||||
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
|
||||
git_show_ref tags/$LATEST_TAG
|
||||
echo
|
||||
fi
|
||||
|
||||
echo -e "${BGREEN}HEAD${RESET}:"
|
||||
git_show
|
||||
echo
|
||||
|
||||
echo -e "${BGREEN}git remote -v${RESET}:"
|
||||
git remote -v
|
||||
echo
|
||||
|
||||
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on origin remote${RESET}? "
|
||||
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
|
||||
read -e -p 'y/n [n]: ' input
|
||||
RESPONSE=${input:='n'}
|
||||
if [ "$RESPONSE" == 'y' ]; then
|
||||
echo
|
||||
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on origin remote..."
|
||||
git push origin :refs/tags/$VERSION
|
||||
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
|
||||
git push --delete origin $VERSION
|
||||
|
||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} on locally..."
|
||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
|
||||
if ! git tag -fa $VERSION; then
|
||||
echo "Existing/invalid tag"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
echo -e "Pushing updated tags to origin remote..."
|
||||
echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
|
||||
git push --delete origin $LATEST_TAG
|
||||
|
||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
|
||||
git tag -fa $LATEST_TAG
|
||||
|
||||
echo -e "Pushing updated tags to remote..."
|
||||
git push origin --tags
|
||||
fi
|
||||
exit 0
|
||||
|
||||
@@ -4,6 +4,7 @@ from logging import Logger
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
||||
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
@@ -15,22 +16,24 @@ from ..services.board_image_records.board_image_records_sqlite import SqliteBoar
|
||||
from ..services.board_images.board_images_default import BoardImagesService
|
||||
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
from ..services.boards.boards_default import BoardService
|
||||
from ..services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from ..services.config import InvokeAIAppConfig
|
||||
from ..services.download import DownloadQueueService
|
||||
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from ..services.images.images_default import ImageService
|
||||
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from ..services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
||||
from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_metadata import ModelMetadataStoreSQL
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from ..services.shared.graph import GraphExecutionState
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from .events import FastAPIEventService
|
||||
@@ -72,8 +75,6 @@ class ApiDependencies:
|
||||
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
model_images_folder = config.models_path
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
configuration = config
|
||||
@@ -84,7 +85,7 @@ class ApiDependencies:
|
||||
board_records = SqliteBoardRecordStorage(db=db)
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
bulk_download = BulkDownloadService()
|
||||
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
@@ -95,15 +96,17 @@ class ApiDependencies:
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_metadata_service = ModelMetadataStoreSQL(db=db)
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db),
|
||||
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
|
||||
download_queue=download_queue_service,
|
||||
events=events,
|
||||
)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
processor = DefaultInvocationProcessor()
|
||||
queue = MemoryInvocationQueue()
|
||||
session_processor = DefaultSessionProcessor()
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
@@ -114,19 +117,20 @@ class ApiDependencies:
|
||||
board_images=board_images,
|
||||
board_records=board_records,
|
||||
boards=boards,
|
||||
bulk_download=bulk_download,
|
||||
configuration=configuration,
|
||||
events=events,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
image_files=image_files,
|
||||
image_records=image_records,
|
||||
images=images,
|
||||
invocation_cache=invocation_cache,
|
||||
logger=logger,
|
||||
model_images=model_images_service,
|
||||
model_manager=model_manager,
|
||||
download_queue=download_queue_service,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
processor=processor,
|
||||
queue=queue,
|
||||
session_processor=session_processor,
|
||||
session_queue=session_queue,
|
||||
urls=urls,
|
||||
|
||||
@@ -2,7 +2,7 @@ import io
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
@@ -375,67 +375,16 @@ async def unstar_images_in_list(
|
||||
|
||||
class ImagesDownloaded(BaseModel):
|
||||
response: Optional[str] = Field(
|
||||
default=None, description="The message to display to the user when images begin downloading"
|
||||
)
|
||||
bulk_download_item_name: Optional[str] = Field(
|
||||
default=None, description="The name of the bulk download item for which events will be emitted"
|
||||
description="If defined, the message to display to the user when images begin downloading"
|
||||
)
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
|
||||
)
|
||||
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded)
|
||||
async def download_images_from_list(
|
||||
background_tasks: BackgroundTasks,
|
||||
image_names: Optional[list[str]] = Body(
|
||||
default=None, description="The list of names of images to download", embed=True
|
||||
),
|
||||
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
|
||||
board_id: Optional[str] = Body(
|
||||
default=None, description="The board from which image should be downloaded", embed=True
|
||||
default=None, description="The board from which image should be downloaded from", embed=True
|
||||
),
|
||||
) -> ImagesDownloaded:
|
||||
if (image_names is None or len(image_names) == 0) and board_id is None:
|
||||
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
||||
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
|
||||
|
||||
background_tasks.add_task(
|
||||
ApiDependencies.invoker.services.bulk_download.handler,
|
||||
image_names,
|
||||
board_id,
|
||||
bulk_download_item_id,
|
||||
)
|
||||
return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip")
|
||||
|
||||
|
||||
@images_router.api_route(
|
||||
"/download/{bulk_download_item_name}",
|
||||
methods=["GET"],
|
||||
operation_id="get_bulk_download_item",
|
||||
response_class=Response,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Return the complete bulk download item",
|
||||
"content": {"application/zip": {}},
|
||||
},
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
async def get_bulk_download_item(
|
||||
background_tasks: BackgroundTasks,
|
||||
bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a bulk download zip file"""
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="application/zip",
|
||||
filename=bulk_download_item_name,
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name)
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
# return ImagesDownloaded(response="Your images are downloading")
|
||||
raise HTTPException(status_code=501, detail="Endpoint is not yet implemented")
|
||||
|
||||
@@ -1,26 +1,27 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
import io
|
||||
import pathlib
|
||||
import shutil
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelSummary,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@@ -29,15 +30,13 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||
|
||||
# images are immutable; set a high max-age
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
@@ -47,6 +46,15 @@ class ModelsList(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class ModelTagSet(BaseModel):
|
||||
"""Return tags for a set of models."""
|
||||
|
||||
key: str
|
||||
name: str
|
||||
author: str
|
||||
tags: Set[str]
|
||||
|
||||
|
||||
##############################################################################
|
||||
# These are example inputs and outputs that are used in places where Swagger
|
||||
# is unable to generate a correct example.
|
||||
@@ -57,16 +65,19 @@ example_model_config = {
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config_path": "string",
|
||||
"config": "string",
|
||||
"key": "string",
|
||||
"hash": "string",
|
||||
"original_hash": "string",
|
||||
"current_hash": "string",
|
||||
"description": "string",
|
||||
"source": "string",
|
||||
"converted_at": 0,
|
||||
"last_modified": 0,
|
||||
"vae": "string",
|
||||
"variant": "normal",
|
||||
"prediction_type": "epsilon",
|
||||
"repo_variant": "fp16",
|
||||
"upcast_attention": False,
|
||||
"ztsnr_training": False,
|
||||
}
|
||||
|
||||
example_model_input = {
|
||||
@@ -75,12 +86,50 @@ example_model_input = {
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config_path": "configs/stable-diffusion/v1-inference.yaml",
|
||||
"config": "configs/stable-diffusion/v1-inference.yaml",
|
||||
"description": "Model description",
|
||||
"vae": None,
|
||||
"variant": "normal",
|
||||
}
|
||||
|
||||
example_model_metadata = {
|
||||
"name": "ip_adapter_sd_image_encoder",
|
||||
"author": "InvokeAI",
|
||||
"tags": [
|
||||
"transformers",
|
||||
"safetensors",
|
||||
"clip_vision_model",
|
||||
"endpoints_compatible",
|
||||
"region:us",
|
||||
"has_space",
|
||||
"license:apache-2.0",
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
|
||||
"path": "ip_adapter_sd_image_encoder/README.md",
|
||||
"size": 628,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
|
||||
"path": "ip_adapter_sd_image_encoder/config.json",
|
||||
"size": 560,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
|
||||
"path": "ip_adapter_sd_image_encoder/model.safetensors",
|
||||
"size": 2528373448,
|
||||
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
|
||||
},
|
||||
],
|
||||
"type": "huggingface",
|
||||
"id": "InvokeAI/ip_adapter_sd_image_encoder",
|
||||
"tag_dict": {"license": "apache-2.0"},
|
||||
"last_modified": "2023-09-23T17:33:25Z",
|
||||
}
|
||||
|
||||
##############################################################################
|
||||
# ROUTES
|
||||
##############################################################################
|
||||
@@ -112,33 +161,9 @@ async def list_model_records(
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
for model in found_models:
|
||||
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
|
||||
model.cover_image = cover_image
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/get_by_attrs",
|
||||
operation_id="get_model_records_by_attrs",
|
||||
response_model=AnyModelConfig,
|
||||
)
|
||||
async def get_model_records_by_attrs(
|
||||
name: str = Query(description="The name of the model"),
|
||||
type: ModelType = Query(description="The type of the model"),
|
||||
base: BaseModelType = Query(description="The base model of the model"),
|
||||
) -> AnyModelConfig:
|
||||
"""Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old
|
||||
model manager, which identified models by a combination of name, base and type."""
|
||||
configs = ApiDependencies.invoker.services.model_manager.store.search_by_attr(
|
||||
base_model=base, model_type=type, model_name=name
|
||||
)
|
||||
if not configs:
|
||||
raise HTTPException(status_code=404, detail="No model found with these attributes")
|
||||
|
||||
return configs[0]
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}",
|
||||
operation_id="get_model_record",
|
||||
@@ -158,92 +183,68 @@ async def get_model_record(
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
config: AnyModelConfig = record_store.get_model(key)
|
||||
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
|
||||
config.cover_image = cover_image
|
||||
return config
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
# @model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
# async def list_model_summary(
|
||||
# page: int = Query(default=0, description="The page to get"),
|
||||
# per_page: int = Query(default=10, description="The number of models per page"),
|
||||
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
# ) -> PaginatedResults[ModelSummary]:
|
||||
# """Gets a page of model summary data."""
|
||||
# record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
# return results
|
||||
|
||||
|
||||
class FoundModel(BaseModel):
|
||||
path: str = Field(description="Path to the model")
|
||||
is_installed: bool = Field(description="Whether or not the model is already installed")
|
||||
@model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
return results
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/scan_folder",
|
||||
operation_id="scan_for_models",
|
||||
"/meta/i/{key}",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
200: {"description": "Directory scanned successfully"},
|
||||
400: {"description": "Invalid directory path"},
|
||||
200: {
|
||||
"description": "The model metadata was retrieved successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "No metadata available"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=List[FoundModel],
|
||||
)
|
||||
async def scan_for_models(
|
||||
scan_path: str = Query(description="Directory path to search for models", default=None),
|
||||
) -> List[FoundModel]:
|
||||
path = pathlib.Path(scan_path)
|
||||
if not scan_path or not path.is_dir():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"The search path '{scan_path}' does not exist or is not directory",
|
||||
)
|
||||
async def get_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
|
||||
return result
|
||||
|
||||
search = ModelSearch()
|
||||
try:
|
||||
found_model_paths = search.search(path)
|
||||
models_path = ApiDependencies.invoker.services.configuration.models_path
|
||||
|
||||
# If the search path includes the main models directory, we need to exclude core models from the list.
|
||||
# TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
|
||||
# without needing to crawl the filesystem.
|
||||
core_models_path = pathlib.Path(models_path, "core").resolve()
|
||||
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
|
||||
@model_manager_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Set[str] = record_store.list_tags()
|
||||
return result
|
||||
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
|
||||
resolved_installed_model_paths: list[str] = []
|
||||
installed_model_sources: list[str] = []
|
||||
|
||||
# This call lists all installed models.
|
||||
for model in installed_models:
|
||||
path = pathlib.Path(model.path)
|
||||
# If the model has a source, we need to add it to the list of installed sources.
|
||||
if model.source:
|
||||
installed_model_sources.append(model.source)
|
||||
# If the path is not absolute, that means it is in the app models directory, and we need to join it with
|
||||
# the models path before resolving.
|
||||
if not path.is_absolute():
|
||||
resolved_installed_model_paths.append(str(pathlib.Path(models_path, path).resolve()))
|
||||
continue
|
||||
resolved_installed_model_paths.append(str(path.resolve()))
|
||||
|
||||
scan_results: list[FoundModel] = []
|
||||
|
||||
# Check if the model is installed by comparing the resolved paths, appending to the scan result.
|
||||
for p in non_core_model_paths:
|
||||
path = str(p)
|
||||
is_installed = path in resolved_installed_model_paths or path in installed_model_sources
|
||||
found_model = FoundModel(path=path, is_installed=is_installed)
|
||||
scan_results.append(found_model)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An error occurred while searching the directory: {e}",
|
||||
)
|
||||
return scan_results
|
||||
@model_manager_router.get(
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
async def search_by_metadata_tags(
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
@@ -262,13 +263,15 @@ async def scan_for_models(
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
|
||||
info: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Update a model's config."""
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
|
||||
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -278,85 +281,16 @@ async def update_model_record(
|
||||
return model_response
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}/image",
|
||||
operation_id="get_model_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model image was fetched successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model image could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def get_model_image(
|
||||
key: str = Path(description="The name of model image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets an image file that previews the model"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.model_images.get_path(key)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=key + ".png",
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}/image",
|
||||
operation_id="update_model_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model image was updated successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def update_model_image(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
image: UploadFile,
|
||||
) -> None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_images = ApiDependencies.invoker.services.model_images
|
||||
try:
|
||||
model_images.save(pil_image, key)
|
||||
logger.info(f"Updated image for model: {key}")
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="delete_model",
|
||||
operation_id="del_model_record",
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_model(
|
||||
async def del_model_record(
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""
|
||||
@@ -377,67 +311,47 @@ async def delete_model(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}/image",
|
||||
operation_id="delete_model_image",
|
||||
@model_manager_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
204: {"description": "Model image deleted successfully"},
|
||||
404: {"description": "Model image not found"},
|
||||
201: {
|
||||
"description": "The model added successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=204,
|
||||
status_code=201,
|
||||
)
|
||||
async def delete_model_image(
|
||||
key: str = Path(description="Unique key of model image to remove from model_images directory."),
|
||||
) -> None:
|
||||
async def add_model_record(
|
||||
config: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_images = ApiDependencies.invoker.services.model_images
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
try:
|
||||
model_images.delete(key)
|
||||
logger.info(f"Deleted model image: {key}")
|
||||
return
|
||||
except UnknownModelException as e:
|
||||
record_store.add_model(config.key, config)
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
|
||||
# @model_manager_router.post(
|
||||
# "/i/",
|
||||
# operation_id="add_model_record",
|
||||
# responses={
|
||||
# 201: {
|
||||
# "description": "The model added successfully",
|
||||
# "content": {"application/json": {"example": example_model_config}},
|
||||
# },
|
||||
# 409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
# 415: {"description": "Unrecognized file/folder format"},
|
||||
# },
|
||||
# status_code=201,
|
||||
# )
|
||||
# async def add_model_record(
|
||||
# config: Annotated[
|
||||
# AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
# ],
|
||||
# ) -> AnyModelConfig:
|
||||
# """Add a model using the configuration information appropriate for its type."""
|
||||
# logger = ApiDependencies.invoker.services.logger
|
||||
# record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
# try:
|
||||
# record_store.add_model(config)
|
||||
# except DuplicateModelException as e:
|
||||
# logger.error(str(e))
|
||||
# raise HTTPException(status_code=409, detail=str(e))
|
||||
# except InvalidModelException as e:
|
||||
# logger.error(str(e))
|
||||
# raise HTTPException(status_code=415)
|
||||
|
||||
# # now fetch it out
|
||||
# result: AnyModelConfig = record_store.get_model(config.key)
|
||||
# return result
|
||||
# now fetch it out
|
||||
result: AnyModelConfig = record_store.get_model(config.key)
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/install",
|
||||
operation_id="install_model",
|
||||
"/heuristic_import",
|
||||
operation_id="heuristic_import_model",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
@@ -446,14 +360,12 @@ async def delete_model_image(
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def install_model(
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
# TODO(MM2): Can we type this?
|
||||
async def heuristic_import(
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
example={"name": "string", "description": "string"},
|
||||
example={"name": "modelT", "description": "antique cars"},
|
||||
),
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
@@ -490,8 +402,106 @@ async def install_model(
|
||||
result: ModelInstallJob = installer.heuristic_import(
|
||||
source=source,
|
||||
config=config,
|
||||
access_token=access_token,
|
||||
inplace=bool(inplace),
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/install",
|
||||
operation_id="import_model",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def import_model(
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using its local path, repo_id, or remote URL.
|
||||
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has `status` attribute
|
||||
that can be used to monitor progress.
|
||||
|
||||
The source object is a discriminated Union of LocalModelSource,
|
||||
HFModelSource and URLModelSource. Set the "type" field to the
|
||||
appropriate value:
|
||||
|
||||
* To install a local path using LocalModelSource, pass a source of form:
|
||||
```
|
||||
{
|
||||
"type": "local",
|
||||
"path": "/path/to/model",
|
||||
"inplace": false
|
||||
}
|
||||
```
|
||||
The "inplace" flag, if true, will register the model in place in its
|
||||
current filesystem location. Otherwise, the model will be copied
|
||||
into the InvokeAI models directory.
|
||||
|
||||
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
|
||||
```
|
||||
{
|
||||
"type": "hf",
|
||||
"repo_id": "stabilityai/stable-diffusion-2.0",
|
||||
"variant": "fp16",
|
||||
"subfolder": "vae",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}
|
||||
```
|
||||
The `variant`, `subfolder` and `access_token` fields are optional.
|
||||
|
||||
* To install a remote model using an arbitrary URL, pass:
|
||||
```
|
||||
{
|
||||
"type": "url",
|
||||
"url": "http://www.civitai.com/models/123456",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}
|
||||
```
|
||||
The `access_token` field is optonal
|
||||
|
||||
The model's configuration record will be probed and filled in
|
||||
automatically. To override the default guesses, pass "metadata"
|
||||
with a Dict containing the attributes you wish to override.
|
||||
|
||||
Installation occurs in the background. Either use list_model_install_jobs()
|
||||
to poll for completion, or listen on the event bus for the following events:
|
||||
|
||||
* "model_install_running"
|
||||
* "model_install_completed"
|
||||
* "model_install_error"
|
||||
|
||||
On successful completion, the event's payload will contain the field "key"
|
||||
containing the installed ID of the model. On an error, the event's payload
|
||||
will contain the fields "error_type" and "error" describing the nature of the
|
||||
error and its traceback, respectively.
|
||||
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
result: ModelInstallJob = installer.import_model(
|
||||
source=source,
|
||||
config=config,
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
@@ -507,10 +517,10 @@ async def install_model(
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/install",
|
||||
operation_id="list_model_installs",
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_model_installs() -> List[ModelInstallJob]:
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
"""Return the list of model install jobs.
|
||||
|
||||
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
||||
@@ -524,8 +534,9 @@ async def list_model_installs() -> List[ModelInstallJob]:
|
||||
* "cancelled" -- Job was cancelled before completion.
|
||||
|
||||
Once completed, information about the model such as its size, base
|
||||
model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
|
||||
information on individual files can be retrieved from `download_parts`.
|
||||
model, type, and metadata can be retrieved from the `config_out`
|
||||
field. For multi-file models such as diffusers, information on individual files
|
||||
can be retrieved from `download_parts`.
|
||||
|
||||
See the example and schema below for more information.
|
||||
"""
|
||||
@@ -534,7 +545,7 @@ async def list_model_installs() -> List[ModelInstallJob]:
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/install/{id}",
|
||||
"/import/{id}",
|
||||
operation_id="get_model_install_job",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
@@ -554,7 +565,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/install/{id}",
|
||||
"/import/{id}",
|
||||
operation_id="cancel_model_install_job",
|
||||
responses={
|
||||
201: {"description": "The job was cancelled successfully"},
|
||||
@@ -572,8 +583,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/install",
|
||||
@model_manager_router.patch(
|
||||
"/import",
|
||||
operation_id="prune_model_install_jobs",
|
||||
responses={
|
||||
204: {"description": "All completed and errored jobs have been pruned"},
|
||||
@@ -626,7 +637,6 @@ async def convert_model(
|
||||
Note that during the conversion process the key and model hash will change.
|
||||
The return value is the model configuration for the converted model.
|
||||
"""
|
||||
model_manager = ApiDependencies.invoker.services.model_manager
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
loader = ApiDependencies.invoker.services.model_manager.load
|
||||
store = ApiDependencies.invoker.services.model_manager.store
|
||||
@@ -643,7 +653,7 @@ async def convert_model(
|
||||
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
||||
|
||||
# loading the model will convert it into a cached diffusers file
|
||||
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
|
||||
loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
||||
|
||||
# Get the path of the converted model from the loader
|
||||
cache_path = loader.convert_cache.cache_path(key)
|
||||
@@ -652,8 +662,7 @@ async def convert_model(
|
||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||
original_name = model_config.name
|
||||
model_config.name = f"{original_name}.DELETE"
|
||||
changes = ModelRecordChanges(name=model_config.name)
|
||||
store.update_model(key, changes=changes)
|
||||
store.update_model(key, config=model_config)
|
||||
|
||||
# install the diffusers
|
||||
try:
|
||||
@@ -662,7 +671,7 @@ async def convert_model(
|
||||
config={
|
||||
"name": original_name,
|
||||
"description": model_config.description,
|
||||
"hash": model_config.hash,
|
||||
"original_hash": model_config.original_hash,
|
||||
"source": model_config.source,
|
||||
},
|
||||
)
|
||||
@@ -670,6 +679,10 @@ async def convert_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
# get the original metadata
|
||||
if orig_metadata := store.get_metadata(key):
|
||||
store.metadata_store.add_metadata(new_key, orig_metadata)
|
||||
|
||||
# delete the original safetensors file
|
||||
installer.delete(key)
|
||||
|
||||
@@ -681,66 +694,66 @@ async def convert_model(
|
||||
return new_config
|
||||
|
||||
|
||||
# @model_manager_router.put(
|
||||
# "/merge",
|
||||
# operation_id="merge",
|
||||
# responses={
|
||||
# 200: {
|
||||
# "description": "Model converted successfully",
|
||||
# "content": {"application/json": {"example": example_model_config}},
|
||||
# },
|
||||
# 400: {"description": "Bad request"},
|
||||
# 404: {"description": "Model not found"},
|
||||
# 409: {"description": "There is already a model registered at this location"},
|
||||
# },
|
||||
# )
|
||||
# async def merge(
|
||||
# keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
# merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
# alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
# force: bool = Body(
|
||||
# description="Force merging of models created with different versions of diffusers",
|
||||
# default=False,
|
||||
# ),
|
||||
# interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
# merge_dest_directory: Optional[str] = Body(
|
||||
# description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
# default=None,
|
||||
# ),
|
||||
# ) -> AnyModelConfig:
|
||||
# """
|
||||
# Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
|
||||
# ```
|
||||
# Argument Description [default]
|
||||
# -------- ----------------------
|
||||
# keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
# merged_model_name Name for the merged model [Concat model names]
|
||||
# alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
# force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
# interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
# merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
# ```
|
||||
# """
|
||||
# logger = ApiDependencies.invoker.services.logger
|
||||
# try:
|
||||
# logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
# dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
# installer = ApiDependencies.invoker.services.model_manager.install
|
||||
# merger = ModelMerger(installer)
|
||||
# model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
# response = merger.merge_diffusion_models_and_save(
|
||||
# model_keys=keys,
|
||||
# merged_model_name=merged_model_name or "+".join(model_names),
|
||||
# alpha=alpha,
|
||||
# interp=interp,
|
||||
# force=force,
|
||||
# merge_dest_directory=dest,
|
||||
# )
|
||||
# except UnknownModelException:
|
||||
# raise HTTPException(
|
||||
# status_code=404,
|
||||
# detail=f"One or more of the models '{keys}' not found",
|
||||
# )
|
||||
# except ValueError as e:
|
||||
# raise HTTPException(status_code=400, detail=str(e))
|
||||
# return response
|
||||
@model_manager_router.put(
|
||||
"/merge",
|
||||
operation_id="merge",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Model converted successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
409: {"description": "There is already a model registered at this location"},
|
||||
},
|
||||
)
|
||||
async def merge(
|
||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
force: bool = Body(
|
||||
description="Force merging of models created with different versions of diffusers",
|
||||
default=False,
|
||||
),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
merge_dest_directory: Optional[str] = Body(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
|
||||
```
|
||||
Argument Description [default]
|
||||
-------- ----------------------
|
||||
keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
merged_model_name Name for the merged model [Concat model names]
|
||||
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
```
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
merger = ModelMerger(installer)
|
||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
response = merger.merge_diffusion_models_and_save(
|
||||
model_keys=keys,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
except UnknownModelException:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"One or more of the models '{keys}' not found",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
276
invokeai/app/api/routers/sessions.py
Normal file
276
invokeai/app/api/routers/sessions.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
|
||||
from fastapi import HTTPException, Path
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from ...services.shared.graph import GraphExecutionState
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||
|
||||
|
||||
# @session_router.post(
|
||||
# "/",
|
||||
# operation_id="create_session",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid json"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def create_session(
|
||||
# queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
|
||||
# graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Creates a new session, optionally initializing it with an invocation graph"""
|
||||
# session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
|
||||
# return session
|
||||
|
||||
|
||||
# @session_router.get(
|
||||
# "/",
|
||||
# operation_id="list_sessions",
|
||||
# responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def list_sessions(
|
||||
# page: int = Query(default=0, description="The page of results to get"),
|
||||
# per_page: int = Query(default=10, description="The number of results per page"),
|
||||
# query: str = Query(default="", description="The query string to search for"),
|
||||
# ) -> PaginatedResults[GraphExecutionState]:
|
||||
# """Gets a list of sessions, optionally searching"""
|
||||
# if query == "":
|
||||
# result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
||||
# else:
|
||||
# result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
||||
# return result
|
||||
|
||||
|
||||
@session_router.get(
|
||||
"/{session_id}",
|
||||
operation_id="get_session",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str = Path(description="The id of the session to get"),
|
||||
) -> GraphExecutionState:
|
||||
"""Gets a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
else:
|
||||
return session
|
||||
|
||||
|
||||
# @session_router.post(
|
||||
# "/{session_id}/nodes",
|
||||
# operation_id="add_node",
|
||||
# responses={
|
||||
# 200: {"model": str},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def add_node(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
# description="The node to add"
|
||||
# ),
|
||||
# ) -> str:
|
||||
# """Adds a node to the graph"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
# try:
|
||||
# session.add_node(node)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session.id
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# @session_router.put(
|
||||
# "/{session_id}/nodes/{node_path}",
|
||||
# operation_id="update_node",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def update_node(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# node_path: str = Path(description="The path to the node in the graph"),
|
||||
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
# description="The new node"
|
||||
# ),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Updates a node in the graph and removes all linked edges"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
# try:
|
||||
# session.update_node(node_path, node)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# @session_router.delete(
|
||||
# "/{session_id}/nodes/{node_path}",
|
||||
# operation_id="delete_node",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def delete_node(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# node_path: str = Path(description="The path to the node to delete"),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Deletes a node in the graph and removes all linked edges"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
# try:
|
||||
# session.delete_node(node_path)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# @session_router.post(
|
||||
# "/{session_id}/edges",
|
||||
# operation_id="add_edge",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def add_edge(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# edge: Edge = Body(description="The edge to add"),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Adds an edge to the graph"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
# try:
|
||||
# session.add_edge(edge)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# # TODO: the edge being in the path here is really ugly, find a better solution
|
||||
# @session_router.delete(
|
||||
# "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
|
||||
# operation_id="delete_edge",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def delete_edge(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# from_node_id: str = Path(description="The id of the node the edge is coming from"),
|
||||
# from_field: str = Path(description="The field of the node the edge is coming from"),
|
||||
# to_node_id: str = Path(description="The id of the node the edge is going to"),
|
||||
# to_field: str = Path(description="The field of the node the edge is going to"),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Deletes an edge from the graph"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
# try:
|
||||
# edge = Edge(
|
||||
# source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||
# destination=EdgeConnection(node_id=to_node_id, field=to_field),
|
||||
# )
|
||||
# session.delete_edge(edge)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# @session_router.put(
|
||||
# "/{session_id}/invoke",
|
||||
# operation_id="invoke_session",
|
||||
# responses={
|
||||
# 200: {"model": None},
|
||||
# 202: {"description": "The invocation is queued"},
|
||||
# 400: {"description": "The session has no invocations ready to invoke"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def invoke_session(
|
||||
# queue_id: str = Query(description="The id of the queue to associate the session with"),
|
||||
# session_id: str = Path(description="The id of the session to invoke"),
|
||||
# all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
||||
# ) -> Response:
|
||||
# """Invokes a session"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
# if session.is_complete():
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
# ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
|
||||
# return Response(status_code=202)
|
||||
|
||||
|
||||
# @session_router.delete(
|
||||
# "/{session_id}/invoke",
|
||||
# operation_id="cancel_session_invoke",
|
||||
# responses={202: {"description": "The invocation is canceled"}},
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def cancel_session_invoke(
|
||||
# session_id: str = Path(description="The id of the session to cancel"),
|
||||
# ) -> Response:
|
||||
# """Invokes a session"""
|
||||
# ApiDependencies.invoker.cancel(session_id)
|
||||
# return Response(status_code=202)
|
||||
@@ -12,26 +12,16 @@ class SocketIO:
|
||||
__sio: AsyncServer
|
||||
__app: ASGIApp
|
||||
|
||||
__sub_queue: str = "subscribe_queue"
|
||||
__unsub_queue: str = "unsubscribe_queue"
|
||||
|
||||
__sub_bulk_download: str = "subscribe_bulk_download"
|
||||
__unsub_bulk_download: str = "unsubscribe_bulk_download"
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
|
||||
app.mount("/ws", self.__app)
|
||||
|
||||
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
|
||||
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
|
||||
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
|
||||
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
|
||||
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
|
||||
|
||||
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
|
||||
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
|
||||
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
|
||||
|
||||
async def _handle_queue_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
@@ -49,18 +39,3 @@ class SocketIO:
|
||||
|
||||
async def _handle_model_event(self, event: Event) -> None:
|
||||
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
|
||||
|
||||
async def _handle_bulk_download_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["bulk_download_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
|
||||
if "bulk_download_id" in data:
|
||||
await self.__sio.enter_room(sid, data["bulk_download_id"])
|
||||
|
||||
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
|
||||
if "bulk_download_id" in data:
|
||||
await self.__sio.leave_room(sid, data["bulk_download_id"])
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
# values from the command line or config file.
|
||||
import sys
|
||||
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
@@ -17,7 +19,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -38,7 +39,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
@@ -50,6 +50,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
images,
|
||||
model_manager,
|
||||
session_queue,
|
||||
sessions,
|
||||
utilities,
|
||||
workflows,
|
||||
)
|
||||
@@ -58,7 +59,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
BaseInvocation,
|
||||
UIConfigBase,
|
||||
)
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
@@ -72,25 +72,9 @@ logger = InvokeAILogger.get_logger(config=app_config)
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Add startup event to load dependencies
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
yield
|
||||
# Shut down threads
|
||||
ApiDependencies.shutdown()
|
||||
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
app = FastAPI(
|
||||
title="Invoke - Community Edition",
|
||||
docs_url=None,
|
||||
redoc_url=None,
|
||||
separate_input_output_schemas=False,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
app = FastAPI(title="Invoke - Community Edition", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
|
||||
|
||||
# Add event handler
|
||||
event_handler_id: int = id(app)
|
||||
@@ -113,7 +97,21 @@ app.add_middleware(
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event("startup")
|
||||
async def startup_event() -> None:
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
|
||||
|
||||
# Shut down threads
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event() -> None:
|
||||
ApiDependencies.shutdown()
|
||||
|
||||
|
||||
# Include all routers
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||
app.include_router(download_queue.download_queue_router, prefix="/api")
|
||||
@@ -153,8 +151,6 @@ def custom_openapi() -> dict[str, Any]:
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
# This could break in some cases, figure out a better way to do it
|
||||
output_type_titles[schema_key] = output_schema["title"]
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
|
||||
|
||||
# Add Node Editor UI helper schemas
|
||||
ui_config_schemas = models_json_schema(
|
||||
@@ -177,6 +173,7 @@ def custom_openapi() -> dict[str, Any]:
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
invoker_schema["output"] = outputs_ref
|
||||
invoker_schema["class"] = "invocation"
|
||||
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
|
||||
|
||||
# This code no longer seems to be necessary?
|
||||
# Leave it here just in case
|
||||
|
||||
@@ -8,26 +8,13 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Iterable,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import TypeAliasType
|
||||
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldKind,
|
||||
@@ -97,7 +84,6 @@ class BaseInvocationOutput(BaseModel):
|
||||
"""
|
||||
|
||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||
@@ -110,14 +96,10 @@ class BaseInvocationOutput(BaseModel):
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationOutputsUnion = TypeAliasType(
|
||||
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
|
||||
return cls._typeadapter
|
||||
def get_outputs_union(cls) -> UnionType:
|
||||
"""Gets a union of all invocation outputs."""
|
||||
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
|
||||
return outputs_union # type: ignore [return-value]
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
@@ -166,7 +148,6 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""
|
||||
|
||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
@@ -179,14 +160,10 @@ class BaseInvocation(ABC, BaseModel):
|
||||
cls._invocation_classes.add(invocation)
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationsUnion = TypeAliasType(
|
||||
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
||||
return cls._typeadapter
|
||||
def get_invocations_union(cls) -> UnionType:
|
||||
"""Gets a union of all invocation types."""
|
||||
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
|
||||
return invocations_union # type: ignore [return-value]
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
from typing import Iterator, List, Optional, Tuple, Union, cast
|
||||
from typing import Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
@@ -17,10 +26,16 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ExtraConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from .model import CLIPField
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .model import ClipField
|
||||
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
|
||||
@@ -46,7 +61,7 @@ class CompelInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.compel_prompt,
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
clip: CLIPField = InputField(
|
||||
clip: ClipField = InputField(
|
||||
title="CLIP",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
@@ -54,16 +69,12 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
@@ -71,10 +82,21 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
|
||||
ti_list = []
|
||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model
|
||||
assert isinstance(loaded_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, loaded_model))
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with (
|
||||
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
@@ -82,9 +104,8 @@ class CompelInvocation(BaseInvocation):
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
|
||||
):
|
||||
assert isinstance(text_encoder, CLIPTextModel)
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
@@ -127,18 +148,14 @@ class SDXLPromptInvocationBase:
|
||||
def run_clip_compel(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
clip_field: CLIPField,
|
||||
clip_field: ClipField,
|
||||
prompt: str,
|
||||
get_pooled: bool,
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||
tokenizer_info = context.models.load(clip_field.tokenizer)
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(clip_field.text_encoder)
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
||||
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
@@ -163,7 +180,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
lora_model = lora_info.model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
yield (lora_model, lora.weight)
|
||||
@@ -172,10 +189,25 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
|
||||
ti_list = []
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_model = context.models.load_by_attrs(
|
||||
model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
|
||||
).model
|
||||
assert isinstance(ti_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, ti_model))
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
logger.warning(f'trigger: "{trigger}" not found')
|
||||
except ValueError:
|
||||
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
||||
|
||||
with (
|
||||
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
@@ -183,10 +215,8 @@ class SDXLPromptInvocationBase:
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
|
||||
):
|
||||
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
text_encoder = cast(CLIPTextModel, text_encoder)
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
@@ -253,8 +283,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
target_width: int = InputField(default=1024, description="")
|
||||
target_height: int = InputField(default=1024, description="")
|
||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@@ -340,7 +370,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
crop_top: int = InputField(default=0, description="")
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
|
||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@@ -370,10 +400,10 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
|
||||
|
||||
@invocation_output("clip_skip_output")
|
||||
class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""CLIP skip node output"""
|
||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""Clip skip node output"""
|
||||
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -383,15 +413,15 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CLIPSkipInvocation(BaseInvocation):
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CLIPSkipInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
return CLIPSkipInvocationOutput(
|
||||
return ClipSkipInvocationOutput(
|
||||
clip=self.clip,
|
||||
)
|
||||
|
||||
|
||||
@@ -34,7 +34,6 @@ from invokeai.app.invocations.fields import (
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@@ -52,9 +51,15 @@ CONTROLNET_RESIZE_VALUES = Literal[
|
||||
]
|
||||
|
||||
|
||||
class ControlNetModelField(BaseModel):
|
||||
"""ControlNet model field"""
|
||||
|
||||
key: str = Field(description="Model config record key for the ControlNet model")
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelField = Field(description="The ControlNet model to use")
|
||||
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
@@ -90,7 +95,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
|
||||
@@ -199,7 +199,6 @@ class DenoiseMaskField(BaseModel):
|
||||
|
||||
mask_name: str = Field(description="The name of the mask image")
|
||||
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
|
||||
gradient: bool = Field(default=False, description="Used for gradient inpainting")
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
@@ -228,7 +227,7 @@ class ConditioningField(BaseModel):
|
||||
# endregion
|
||||
|
||||
|
||||
class MetadataField(RootModel[dict[str, Any]]):
|
||||
class MetadataField(RootModel):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
Metadata is stored without a strict schema.
|
||||
|
||||
@@ -22,7 +22,11 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
|
||||
from .baseinvocation import BaseInvocation, Classification, invocation
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
Classification,
|
||||
invocation,
|
||||
)
|
||||
|
||||
|
||||
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1")
|
||||
@@ -930,40 +934,3 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_dto = context.images.save(image=image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"canvas_paste_back",
|
||||
title="Canvas Paste Back",
|
||||
tags=["image", "combine"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Combines two images by using the mask provided. Intended for use on the Unified Canvas."""
|
||||
|
||||
source_image: ImageField = InputField(description="The source image")
|
||||
target_image: ImageField = InputField(default=None, description="The target image")
|
||||
mask: ImageField = InputField(
|
||||
description="The mask to use when pasting",
|
||||
)
|
||||
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
|
||||
|
||||
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
|
||||
mask_array = numpy.array(mask)
|
||||
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
|
||||
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
|
||||
dilated_mask = Image.fromarray(dilated_mask_array)
|
||||
if self.mask_blur > 0:
|
||||
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
||||
return ImageOps.invert(mask.convert("L"))
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
source_image = context.images.get_pil(self.source_image.image_name)
|
||||
target_image = context.images.get_pil(self.target_image.image_name)
|
||||
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
|
||||
|
||||
source_image.paste(target_image, (0, 0), mask)
|
||||
|
||||
image_dto = context.images.save(image=source_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -11,17 +11,25 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import ModelField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
|
||||
|
||||
# LS: Consider moving these two classes into model.py
|
||||
class IPAdapterModelField(BaseModel):
|
||||
key: str = Field(description="Key to the IP-Adapter model")
|
||||
|
||||
|
||||
class CLIPVisionModelField(BaseModel):
|
||||
key: str = Field(description="Key to the CLIP Vision image encoder model")
|
||||
|
||||
|
||||
class IPAdapterField(BaseModel):
|
||||
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
|
||||
image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
|
||||
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
||||
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
|
||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
@@ -54,7 +62,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: ModelField = InputField(
|
||||
ip_adapter_model: IPAdapterModelField = InputField(
|
||||
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
|
||||
)
|
||||
|
||||
@@ -82,18 +90,18 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
assert isinstance(ip_adapter_info, IPAdapterConfig)
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
image_encoder_models = context.models.search_by_attrs(
|
||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
|
||||
)
|
||||
assert len(image_encoder_models) == 1
|
||||
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
|
||||
return IPAdapterOutput(
|
||||
ip_adapter=IPAdapterField(
|
||||
image=self.image,
|
||||
ip_adapter_model=self.ip_adapter_model,
|
||||
image_encoder_model=ModelField(key=image_encoder_models[0].key),
|
||||
image_encoder_model=image_encoder_model,
|
||||
weight=self.weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
|
||||
@@ -23,10 +23,9 @@ from diffusers.models.attention_processor import (
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from PIL import Image, ImageFilter
|
||||
from PIL import Image
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.fields import (
|
||||
@@ -76,7 +75,7 @@ from .baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelField, UNetField, VAEField
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
@@ -119,7 +118,7 @@ class SchedulerInvocation(BaseInvocation):
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
||||
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
||||
@@ -129,7 +128,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
ui_order=4,
|
||||
)
|
||||
|
||||
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
|
||||
def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor:
|
||||
if mask_image.mode != "L":
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
@@ -154,7 +153,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
if image_tensor is not None:
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
@@ -170,87 +169,17 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
return DenoiseMaskOutput.build(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=masked_latents_name,
|
||||
gradient=False,
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("gradient_mask_output")
|
||||
class GradientMaskOutput(BaseInvocationOutput):
|
||||
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
|
||||
|
||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||
expanded_mask_area: ImageField = OutputField(
|
||||
description="Image representing the total gradient area of the mask. For paste-back purposes."
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"create_gradient_mask",
|
||||
title="Create Gradient Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CreateGradientMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
|
||||
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||
edge_radius: int = InputField(
|
||||
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
|
||||
)
|
||||
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
|
||||
minimum_denoise: float = InputField(
|
||||
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
if self.edge_radius > 0:
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
|
||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is less than fully masked, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
|
||||
else:
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
|
||||
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
||||
|
||||
# compute a [0, 1] mask from the blur_tensor
|
||||
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
|
||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||
|
||||
return GradientMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||
)
|
||||
|
||||
|
||||
def get_scheduler(
|
||||
context: InvocationContext,
|
||||
scheduler_info: ModelField,
|
||||
scheduler_info: ModelInfo,
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.models.load(scheduler_info)
|
||||
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
@@ -375,6 +304,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = c.extra_conditioning
|
||||
|
||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
@@ -384,6 +314,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
text_embeddings=c,
|
||||
guidance_scale=self.cfg_scale,
|
||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
extra=extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0, # threshold,
|
||||
warmup=0.2, # warmup,
|
||||
@@ -462,7 +393,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# and if weight is None, populate with default 1.0?
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
|
||||
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
@@ -524,10 +455,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.models.load(single_ip_adapter.ip_adapter_model)
|
||||
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
||||
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
|
||||
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
@@ -538,7 +470,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
single_ipa_images, image_encoder_model
|
||||
@@ -578,8 +509,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
|
||||
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
@@ -675,9 +606,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def prep_inpaint_mask(
|
||||
self, context: InvocationContext, latents: torch.Tensor
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
if self.denoise_mask is None:
|
||||
return None, None, False
|
||||
return None, None
|
||||
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
@@ -686,7 +617,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
else:
|
||||
masked_latents = None
|
||||
|
||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||
return 1 - mask, masked_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
@@ -713,7 +644,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if seed is None:
|
||||
seed = 0
|
||||
|
||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
mask, masked_latents = self.prep_inpaint_mask(context, latents)
|
||||
|
||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||
# below. Investigate whether this is appropriate.
|
||||
@@ -732,13 +663,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
unet_info = context.models.load(**self.unet.unet.model_dump())
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
@@ -791,7 +721,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
result_latents = pipeline.latents_from_embeddings(
|
||||
(
|
||||
result_latents,
|
||||
result_attention_map_saver,
|
||||
) = pipeline.latents_from_embeddings(
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
@@ -799,7 +732,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed=seed,
|
||||
mask=mask,
|
||||
masked_latents=masked_latents,
|
||||
gradient_mask=gradient_mask,
|
||||
num_inference_steps=num_inference_steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=controlnet_data,
|
||||
@@ -832,7 +764,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
@@ -843,8 +775,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
latents = latents.to(vae.device)
|
||||
@@ -1010,7 +942,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode",
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
@@ -1066,7 +998,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
|
||||
@@ -8,10 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import (
|
||||
CONTROLNET_MODE_VALUES,
|
||||
CONTROLNET_RESIZE_VALUES,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
@@ -20,8 +17,10 @@ from invokeai.app.invocations.fields import (
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
|
||||
from ...version import __version__
|
||||
|
||||
@@ -31,20 +30,10 @@ class MetadataItemField(BaseModel):
|
||||
value: Any = Field(description=FieldDescriptions.metadata_item_value)
|
||||
|
||||
|
||||
class ModelMetadataField(BaseModel):
|
||||
"""Model Metadata Field"""
|
||||
|
||||
key: str
|
||||
hash: str
|
||||
name: str
|
||||
base: BaseModelType
|
||||
type: ModelType
|
||||
|
||||
|
||||
class LoRAMetadataField(BaseModel):
|
||||
"""LoRA Metadata Field"""
|
||||
|
||||
model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
|
||||
lora: LoRAModelField = Field(description=FieldDescriptions.lora_model)
|
||||
weight: float = Field(description=FieldDescriptions.lora_weight)
|
||||
|
||||
|
||||
@@ -52,7 +41,7 @@ class IPAdapterMetadataField(BaseModel):
|
||||
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
|
||||
|
||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: ModelMetadataField = Field(
|
||||
ip_adapter_model: IPAdapterModelField = Field(
|
||||
description="The IP-Adapter model.",
|
||||
)
|
||||
weight: Union[float, list[float]] = Field(
|
||||
@@ -62,33 +51,6 @@ class IPAdapterMetadataField(BaseModel):
|
||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||
|
||||
|
||||
class T2IAdapterMetadataField(BaseModel):
|
||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||
t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
|
||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
|
||||
class ControlNetMetadataField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelMetadataField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
|
||||
@invocation_output("metadata_item_output")
|
||||
class MetadataItemOutput(BaseInvocationOutput):
|
||||
"""Metadata Item Output"""
|
||||
@@ -152,7 +114,7 @@ GENERATION_MODES = Literal[
|
||||
]
|
||||
|
||||
|
||||
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.1.1")
|
||||
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.0.1")
|
||||
class CoreMetadataInvocation(BaseInvocation):
|
||||
"""Collects core generation metadata into a MetadataField"""
|
||||
|
||||
@@ -178,14 +140,14 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
|
||||
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
|
||||
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
|
||||
controlnets: Optional[list[ControlField]] = InputField(
|
||||
default=None, description="The ControlNets used for inference"
|
||||
)
|
||||
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
|
||||
default=None, description="The IP Adapters used for inference"
|
||||
)
|
||||
t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
|
||||
t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
|
||||
default=None, description="The IP Adapters used for inference"
|
||||
)
|
||||
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
|
||||
@@ -197,7 +159,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
default=None,
|
||||
description="The name of the initial image",
|
||||
)
|
||||
vae: Optional[ModelMetadataField] = InputField(
|
||||
vae: Optional[VAEModelField] = InputField(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
@@ -228,7 +190,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Optional[ModelMetadataField] = InputField(
|
||||
refiner_model: Optional[MainModelField] = InputField(
|
||||
default=None,
|
||||
description="The SDXL Refiner model used",
|
||||
)
|
||||
@@ -260,9 +222,10 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
||||
as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||
as_dict["app_version"] = __version__
|
||||
|
||||
return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
|
||||
return MetadataOutput(
|
||||
metadata=MetadataField.model_validate(
|
||||
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||
)
|
||||
)
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
|
||||
from ...backend.model_manager import SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -16,33 +16,33 @@ from .baseinvocation import (
|
||||
)
|
||||
|
||||
|
||||
class ModelField(BaseModel):
|
||||
key: str = Field(description="Key of the model")
|
||||
submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
|
||||
class ModelInfo(BaseModel):
|
||||
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
|
||||
class LoRAField(BaseModel):
|
||||
lora: ModelField = Field(description="Info to load lora model")
|
||||
weight: float = Field(description="Weight to apply to lora model")
|
||||
class LoraInfo(ModelInfo):
|
||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||
|
||||
|
||||
class UNetField(BaseModel):
|
||||
unet: ModelField = Field(description="Info to load unet submodel")
|
||||
scheduler: ModelField = Field(description="Info to load scheduler submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
unet: ModelInfo = Field(description="Info to load unet submodel")
|
||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
||||
|
||||
|
||||
class CLIPField(BaseModel):
|
||||
tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
|
||||
class ClipField(BaseModel):
|
||||
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
||||
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
|
||||
|
||||
class VAEField(BaseModel):
|
||||
vae: ModelField = Field(description="Info to load vae submodel")
|
||||
class VaeField(BaseModel):
|
||||
# TODO: better naming?
|
||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
|
||||
|
||||
@@ -57,14 +57,14 @@ class UNetOutput(BaseInvocationOutput):
|
||||
class VAEOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a VAE field"""
|
||||
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation_output("clip_output")
|
||||
class CLIPOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a CLIP field"""
|
||||
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation_output("model_loader_output")
|
||||
@@ -74,6 +74,18 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
||||
pass
|
||||
|
||||
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
key: str = Field(description="Model key")
|
||||
|
||||
|
||||
class LoRAModelField(BaseModel):
|
||||
"""LoRA model field"""
|
||||
|
||||
key: str = Field(description="LoRA model key")
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model",
|
||||
@@ -84,40 +96,62 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
# TODO: not found exceptions
|
||||
if not context.models.exists(self.model.key):
|
||||
raise Exception(f"Unknown model {self.model.key}")
|
||||
key = self.model.key
|
||||
|
||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
# TODO: not found exceptions
|
||||
if not context.models.exists(key):
|
||||
raise Exception(f"Unknown model {key}")
|
||||
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||
vae=VAEField(vae=vae),
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("lora_loader_output")
|
||||
class LoRALoaderOutput(BaseInvocationOutput):
|
||||
class LoraLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
|
||||
class LoRALoaderInvocation(BaseInvocation):
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None,
|
||||
@@ -125,41 +159,46 @@ class LoRALoaderInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[CLIPField] = InputField(
|
||||
clip: Optional[ClipField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise Exception(f"Unkown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to unet')
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to clip')
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
|
||||
output = LoRALoaderOutput()
|
||||
output = LoraLoaderOutput()
|
||||
|
||||
if self.unet is not None:
|
||||
output.unet = self.unet.model_copy(deep=True)
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -168,12 +207,12 @@ class LoRALoaderInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation_output("sdxl_lora_loader_output")
|
||||
class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL LoRA Loader Output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -183,10 +222,10 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
category="model",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None,
|
||||
@@ -194,59 +233,65 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[CLIPField] = InputField(
|
||||
clip: Optional[ClipField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP 1",
|
||||
)
|
||||
clip2: Optional[CLIPField] = InputField(
|
||||
clip2: Optional[ClipField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP 2",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise Exception(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to unet')
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to clip')
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
|
||||
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to clip2')
|
||||
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip2')
|
||||
|
||||
output = SDXLLoRALoaderOutput()
|
||||
output = SDXLLoraLoaderOutput()
|
||||
|
||||
if self.unet is not None:
|
||||
output.unet = self.unet.model_copy(deep=True)
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip2 is not None:
|
||||
output.clip2 = self.clip2.model_copy(deep=True)
|
||||
output.clip2 = copy.deepcopy(self.clip2)
|
||||
output.clip2.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -254,11 +299,17 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
|
||||
key: str = Field(description="Model's key")
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
|
||||
class VAELoaderInvocation(BaseInvocation):
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
vae_model: ModelField = InputField(
|
||||
vae_model: VAEModelField = InputField(
|
||||
description=FieldDescriptions.vae_model,
|
||||
input=Input.Direct,
|
||||
title="VAE",
|
||||
@@ -270,7 +321,7 @@ class VAELoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(key):
|
||||
raise Exception(f"Unkown vae: {key}!")
|
||||
|
||||
return VAEOutput(vae=VAEField(vae=self.vae_model))
|
||||
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
||||
|
||||
|
||||
@invocation_output("seamless_output")
|
||||
@@ -278,7 +329,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
|
||||
"""Modified Seamless Model output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
vae: Optional[VAEField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
||||
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -297,7 +348,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
vae: Optional[VAEField] = InputField(
|
||||
vae: Optional[VaeField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.vae_model,
|
||||
input=Input.Connection,
|
||||
|
||||
@@ -299,13 +299,9 @@ class DenoiseMaskOutput(BaseInvocationOutput):
|
||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, mask_name: str, masked_latents_name: Optional[str] = None, gradient: bool = False
|
||||
) -> "DenoiseMaskOutput":
|
||||
def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput":
|
||||
return cls(
|
||||
denoise_mask=DenoiseMaskField(
|
||||
mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=gradient
|
||||
),
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from .baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .model import CLIPField, ModelField, UNetField, VAEField
|
||||
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
||||
|
||||
|
||||
@invocation_output("sdxl_model_loader_output")
|
||||
@@ -16,9 +16,9 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL base model loader output"""
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation_output("sdxl_refiner_model_loader_output")
|
||||
@@ -26,15 +26,15 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL refiner model loader output"""
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
model: ModelField = InputField(
|
||||
model: MainModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||
)
|
||||
# TODO: precision?
|
||||
@@ -46,19 +46,48 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return SDXLModelLoaderOutput(
|
||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||
vae=VAEField(vae=vae),
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -72,8 +101,10 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
model: ModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
|
||||
model: MainModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model,
|
||||
input=Input.Direct,
|
||||
ui_type=UIType.SDXLRefinerModel,
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
@@ -84,14 +115,34 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return SDXLRefinerModelLoaderOutput(
|
||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||
vae=VAEField(vae=vae),
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -10,14 +10,17 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import ModelField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
class T2IAdapterModelField(BaseModel):
|
||||
key: str = Field(description="Model record key for the T2I-Adapter model")
|
||||
|
||||
|
||||
class T2IAdapterField(BaseModel):
|
||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||
t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
|
||||
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
|
||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||
@@ -52,7 +55,7 @@ class T2IAdapterInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||
t2i_adapter_model: ModelField = InputField(
|
||||
t2i_adapter_model: T2IAdapterModelField = InputField(
|
||||
description="The T2I-Adapter model.",
|
||||
title="T2I-Adapter Model",
|
||||
input=Input.Direct,
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BulkDownloadBase(ABC):
|
||||
"""Responsible for creating a zip file containing the images specified by the given image names or board id."""
|
||||
|
||||
@abstractmethod
|
||||
def handler(
|
||||
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
|
||||
) -> None:
|
||||
"""
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
|
||||
:param image_names: A list of image names to include in the zip file.
|
||||
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
||||
:param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, bulk_download_item_name: str) -> str:
|
||||
"""
|
||||
Get the path to the bulk download file.
|
||||
|
||||
:param bulk_download_item_name: The name of the bulk download item.
|
||||
:return: The path to the bulk download file.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def generate_item_id(self, board_id: Optional[str]) -> str:
|
||||
"""
|
||||
Generate an item ID for a bulk download item.
|
||||
|
||||
:param board_id: The ID of the board whose name is to be included in the item id.
|
||||
:return: The generated item ID.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, bulk_download_item_name: str) -> None:
|
||||
"""
|
||||
Delete the bulk download file.
|
||||
|
||||
:param bulk_download_item_name: The name of the bulk download item.
|
||||
"""
|
||||
@@ -1,25 +0,0 @@
|
||||
DEFAULT_BULK_DOWNLOAD_ID = "default"
|
||||
|
||||
|
||||
class BulkDownloadException(Exception):
|
||||
"""Exception raised when a bulk download fails."""
|
||||
|
||||
def __init__(self, message="Bulk download failed"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
|
||||
class BulkDownloadTargetException(BulkDownloadException):
|
||||
"""Exception raised when a bulk download target is not found."""
|
||||
|
||||
def __init__(self, message="The bulk download target was not found"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
|
||||
class BulkDownloadParametersException(BulkDownloadException):
|
||||
"""Exception raised when a bulk download parameter is invalid."""
|
||||
|
||||
def __init__(self, message="No image names or board ID provided"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
@@ -1,157 +0,0 @@
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional, Union
|
||||
from zipfile import ZipFile
|
||||
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
|
||||
from invokeai.app.services.bulk_download.bulk_download_common import (
|
||||
DEFAULT_BULK_DOWNLOAD_ID,
|
||||
BulkDownloadException,
|
||||
BulkDownloadParametersException,
|
||||
BulkDownloadTargetException,
|
||||
)
|
||||
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
from .bulk_download_base import BulkDownloadBase
|
||||
|
||||
|
||||
class BulkDownloadService(BulkDownloadBase):
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def __init__(self):
|
||||
self._temp_directory = TemporaryDirectory()
|
||||
self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads"
|
||||
self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def handler(
|
||||
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
|
||||
) -> None:
|
||||
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
|
||||
bulk_download_item_id = bulk_download_item_id or uuid_string()
|
||||
bulk_download_item_name = bulk_download_item_id + ".zip"
|
||||
|
||||
self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
||||
|
||||
try:
|
||||
image_dtos: list[ImageDTO] = []
|
||||
|
||||
if board_id:
|
||||
image_dtos = self._board_handler(board_id)
|
||||
elif image_names:
|
||||
image_dtos = self._image_handler(image_names)
|
||||
else:
|
||||
raise BulkDownloadParametersException()
|
||||
|
||||
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
|
||||
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
||||
except (
|
||||
ImageRecordNotFoundException,
|
||||
BoardRecordNotFoundException,
|
||||
BulkDownloadException,
|
||||
BulkDownloadParametersException,
|
||||
) as e:
|
||||
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
|
||||
except Exception as e:
|
||||
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
|
||||
self._invoker.services.logger.error("Problem bulk downloading images.")
|
||||
raise e
|
||||
|
||||
def _image_handler(self, image_names: list[str]) -> list[ImageDTO]:
|
||||
return [self._invoker.services.images.get_dto(image_name) for image_name in image_names]
|
||||
|
||||
def _board_handler(self, board_id: str) -> list[ImageDTO]:
|
||||
image_names = self._invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
return self._image_handler(image_names)
|
||||
|
||||
def generate_item_id(self, board_id: Optional[str]) -> str:
|
||||
return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string()
|
||||
|
||||
def _get_clean_board_name(self, board_id: str) -> str:
|
||||
if board_id == "none":
|
||||
return "Uncategorized"
|
||||
|
||||
return self._clean_string_to_path_safe(self._invoker.services.board_records.get(board_id).board_name)
|
||||
|
||||
def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str:
|
||||
"""
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
If download with the same bulk_download_id already exists, it will be overwritten.
|
||||
|
||||
:return: The name of the zip file.
|
||||
"""
|
||||
zip_file_name = bulk_download_item_id + ".zip"
|
||||
zip_file_path = self._bulk_downloads_folder / (zip_file_name)
|
||||
|
||||
with ZipFile(zip_file_path, "w") as zip_file:
|
||||
for image_dto in image_dtos:
|
||||
image_zip_path = Path(image_dto.image_category.value) / image_dto.image_name
|
||||
image_disk_path = self._invoker.services.images.get_path(image_dto.image_name)
|
||||
zip_file.write(image_disk_path, arcname=image_zip_path)
|
||||
|
||||
return str(zip_file_name)
|
||||
|
||||
# from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string
|
||||
def _clean_string_to_path_safe(self, s: str) -> str:
|
||||
"""Clean a string to be path safe."""
|
||||
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip()
|
||||
|
||||
def _signal_job_started(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Signal that a bulk download job has started."""
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
self._invoker.services.events.emit_bulk_download_started(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
)
|
||||
|
||||
def _signal_job_completed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Signal that a bulk download job has completed."""
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert bulk_download_item_name is not None
|
||||
self._invoker.services.events.emit_bulk_download_completed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
)
|
||||
|
||||
def _signal_job_failed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception
|
||||
) -> None:
|
||||
"""Signal that a bulk download job has failed."""
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert exception is not None
|
||||
self._invoker.services.events.emit_bulk_download_failed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
error=str(exception),
|
||||
)
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self._temp_directory.cleanup()
|
||||
|
||||
def delete(self, bulk_download_item_name: str) -> None:
|
||||
path = self.get_path(bulk_download_item_name)
|
||||
Path(path).unlink()
|
||||
|
||||
def get_path(self, bulk_download_item_name: str) -> str:
|
||||
path = str(self._bulk_downloads_folder / bulk_download_item_name)
|
||||
if not self._is_valid_path(path):
|
||||
raise BulkDownloadTargetException()
|
||||
return path
|
||||
|
||||
def _is_valid_path(self, path: Union[str, Path]) -> bool:
|
||||
"""Validates the path given for a bulk download."""
|
||||
path = path if isinstance(path, Path) else Path(path)
|
||||
return path.exists()
|
||||
@@ -156,7 +156,6 @@ class InvokeAISettings(BaseSettings):
|
||||
"lora_dir",
|
||||
"embedding_dir",
|
||||
"controlnet_dir",
|
||||
"conf_path",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -30,6 +30,7 @@ InvokeAI:
|
||||
lora_dir: null
|
||||
embedding_dir: null
|
||||
controlnet_dir: null
|
||||
conf_path: configs/models.yaml
|
||||
models_dir: models
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
db_dir: databases
|
||||
@@ -122,6 +123,7 @@ a Path object:
|
||||
|
||||
root_path - path to InvokeAI root
|
||||
output_path - path to default outputs directory
|
||||
model_conf_path - path to models.yaml
|
||||
conf - alias for the above
|
||||
embedding_path - path to the embeddings directory
|
||||
lora_path - path to the LoRA directory
|
||||
@@ -161,12 +163,12 @@ two configs are kept in separate sections of the config file:
|
||||
InvokeAI:
|
||||
Paths:
|
||||
root: /home/lstein/invokeai-main
|
||||
conf_path: configs/models.yaml
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
@@ -235,6 +237,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
# PATHS
|
||||
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
|
||||
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
|
||||
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
|
||||
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
|
||||
@@ -256,7 +259,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
|
||||
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
|
||||
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
|
||||
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||
|
||||
@@ -299,7 +301,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||
|
||||
# this is not referred to in the source code and can be removed entirely
|
||||
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Init file for download queue."""
|
||||
|
||||
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
|
||||
from .download_default import DownloadQueueService, TqdmProgress
|
||||
|
||||
|
||||
@@ -224,6 +224,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
job.job_started = get_iso_timestamp()
|
||||
self._do_download(job)
|
||||
self._signal_job_complete(job)
|
||||
|
||||
except (OSError, HTTPError) as excp:
|
||||
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
||||
job.error = traceback.format_exc()
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
@@ -16,7 +16,6 @@ from invokeai.backend.model_manager import AnyModelConfig
|
||||
|
||||
class EventServiceBase:
|
||||
queue_event: str = "queue_event"
|
||||
bulk_download_event: str = "bulk_download_event"
|
||||
download_event: str = "download_event"
|
||||
model_event: str = "model_event"
|
||||
|
||||
@@ -25,14 +24,6 @@ class EventServiceBase:
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
pass
|
||||
|
||||
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Bulk download events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.bulk_download_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Queue events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
@@ -213,6 +204,52 @@ class EventServiceBase:
|
||||
},
|
||||
)
|
||||
|
||||
def emit_session_retrieval_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when session retrieval fails"""
|
||||
self.__emit_queue_event(
|
||||
event_name="session_retrieval_error",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_invocation_retrieval_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when invocation retrieval fails"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_retrieval_error",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node_id": node_id,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_session_canceled(
|
||||
self,
|
||||
queue_id: str,
|
||||
@@ -357,7 +394,6 @@ class EventServiceBase:
|
||||
bytes: int,
|
||||
total_bytes: int,
|
||||
parts: List[Dict[str, Union[str, int]]],
|
||||
id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Emit at intervals while the install job is in progress (remote models only).
|
||||
@@ -377,7 +413,6 @@ class EventServiceBase:
|
||||
"bytes": bytes,
|
||||
"total_bytes": total_bytes,
|
||||
"parts": parts,
|
||||
"id": id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -392,7 +427,7 @@ class EventServiceBase:
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
|
||||
def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None:
|
||||
"""
|
||||
Emit when an install job is completed successfully.
|
||||
|
||||
@@ -402,7 +437,11 @@ class EventServiceBase:
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_completed",
|
||||
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
|
||||
payload={
|
||||
"source": source,
|
||||
"total_bytes": total_bytes,
|
||||
"key": key,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_install_cancelled(self, source: str) -> None:
|
||||
@@ -416,7 +455,12 @@ class EventServiceBase:
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
|
||||
def emit_model_install_error(
|
||||
self,
|
||||
source: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""
|
||||
Emit when an install job encounters an exception.
|
||||
|
||||
@@ -426,45 +470,9 @@ class EventServiceBase:
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_error",
|
||||
payload={"source": source, "error_type": error_type, "error": error, "id": id},
|
||||
)
|
||||
|
||||
def emit_bulk_download_started(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk download starts"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_started",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_bulk_download_completed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk download completes"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_completed",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_bulk_download_failed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk download fails"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_failed",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
"source": source,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -41,9 +41,8 @@ class InvocationCacheBase(ABC):
|
||||
"""Clears the cache"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def create_key(invocation: BaseInvocation) -> int:
|
||||
def create_key(self, invocation: BaseInvocation) -> int:
|
||||
"""Gets the key for the invocation's cache item"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -61,7 +61,9 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
self._delete_oldest_access(number_to_delete)
|
||||
self._cache[key] = CachedItem(
|
||||
invocation_output,
|
||||
invocation_output.model_dump_json(warnings=False, exclude_defaults=True, exclude_unset=True),
|
||||
invocation_output.model_dump_json(
|
||||
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
|
||||
),
|
||||
)
|
||||
|
||||
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||
@@ -79,7 +81,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
with self._lock:
|
||||
return self._delete(key)
|
||||
|
||||
def clear(self) -> None:
|
||||
def clear(self, *args, **kwargs) -> None:
|
||||
with self._lock:
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class InvocationProcessorABC(ABC): # noqa: B024
|
||||
pass
|
||||
@@ -0,0 +1,15 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
|
||||
class CanceledException(Exception):
|
||||
"""Execution canceled by user."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,243 @@
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Event, Thread
|
||||
from typing import Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import (
|
||||
GESStatsNotFoundError,
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .invocation_processor_base import InvocationProcessorABC
|
||||
from .invocation_processor_common import CanceledException
|
||||
|
||||
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
__stop_event: Event
|
||||
__invoker: Invoker
|
||||
__threadLimit: BoundedSemaphore
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
# LS - this will probably break
|
||||
# but the idea is to enable multithreading up to the number of available
|
||||
# GPUs. Nodes will block on model loading if no GPU is free.
|
||||
self.__threadLimit = BoundedSemaphore(invoker.services.model_manager.gpu_count)
|
||||
self.__invoker = invoker
|
||||
self.__stop_event = Event()
|
||||
self.__invoker_thread = Thread(
|
||||
name="invoker_processor",
|
||||
target=self.__process,
|
||||
kwargs={"stop_event": self.__stop_event},
|
||||
)
|
||||
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
|
||||
self.__invoker_thread.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self.__stop_event.set()
|
||||
|
||||
def __process(self, stop_event: Event):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
queue_item: Optional[InvocationQueueItem] = None
|
||||
|
||||
profiler = (
|
||||
Profiler(
|
||||
logger=self.__invoker.services.logger,
|
||||
output_dir=self.__invoker.services.configuration.profiles_path,
|
||||
prefix=self.__invoker.services.configuration.profile_prefix,
|
||||
)
|
||||
if self.__invoker.services.configuration.profile_graphs
|
||||
else None
|
||||
)
|
||||
|
||||
def stats_cleanup(graph_execution_state_id: str) -> None:
|
||||
if profiler:
|
||||
profile_path = profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self.__invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=graph_execution_state_id, output_path=stats_path
|
||||
)
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self.__invoker.services.performance_statistics.log_stats(graph_execution_state_id)
|
||||
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state_id)
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
queue_item = self.__invoker.services.queue.get()
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
|
||||
|
||||
if not queue_item: # Probably stopping
|
||||
# do not hammer the queue
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
if profiler and profiler.profile_id != queue_item.graph_execution_state_id:
|
||||
profiler.start(profile_id=queue_item.graph_execution_state_id)
|
||||
|
||||
try:
|
||||
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
|
||||
queue_item.graph_execution_state_id
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
||||
self.__invoker.services.events.emit_session_retrieval_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_retrieval_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||
node_id=queue_item.invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
continue
|
||||
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||
|
||||
# Send starting event
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
# Invoke
|
||||
try:
|
||||
graph_id = graph_execution_state.id
|
||||
with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
||||
# which handles a few things:
|
||||
# - nodes that require a value, but get it only from a connection
|
||||
# - referencing the invocation cache instead of executing the node
|
||||
context_data = InvocationContextData(
|
||||
invocation=invocation,
|
||||
session_id=graph_id,
|
||||
workflow=queue_item.workflow,
|
||||
source_node_id=source_node_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
batch_id=queue_item.session_queue_batch_id,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
services=self.__invoker.services,
|
||||
context_data=context_data,
|
||||
)
|
||||
outputs = invocation.invoke_internal(context=context, services=self.__invoker.services)
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||
continue
|
||||
|
||||
# Save outputs and history
|
||||
graph_execution_state.complete(invocation.id, outputs)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
stats_cleanup(graph_execution_state.id)
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
logger.error(error)
|
||||
|
||||
# Save error
|
||||
graph_execution_state.set_node_error(invocation.id, error)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
# Send error event
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
pass
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||
continue
|
||||
|
||||
# Queue any further commands if invoking all
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
try:
|
||||
self.__invoker.invoke(
|
||||
session_queue_batch_id=queue_item.session_queue_batch_id,
|
||||
session_queue_item_id=queue_item.session_queue_item_id,
|
||||
session_queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state=graph_execution_state,
|
||||
workflow=queue_item.workflow,
|
||||
invoke_all=True,
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
)
|
||||
stats_cleanup(graph_execution_state.id)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||
finally:
|
||||
self.__threadLimit.release()
|
||||
@@ -0,0 +1,26 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from .invocation_queue_common import InvocationQueueItem
|
||||
|
||||
|
||||
class InvocationQueueABC(ABC):
|
||||
"""Abstract base class for all invocation queues"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self) -> InvocationQueueItem:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(self, item: Optional[InvocationQueueItem]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_canceled(self, graph_execution_state_id: str) -> bool:
|
||||
pass
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
|
||||
session_queue_item_id: int = Field(
|
||||
description="The ID of session queue item from which this invocation queue item came"
|
||||
)
|
||||
session_queue_batch_id: str = Field(
|
||||
description="The ID of the session batch from which this invocation queue item came"
|
||||
)
|
||||
workflow: Optional[WorkflowWithoutID] = Field(description="The workflow associated with this queue item")
|
||||
invoke_all: bool = Field(default=False)
|
||||
timestamp: float = Field(default_factory=time.time)
|
||||
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import time
|
||||
from queue import Queue
|
||||
from typing import Optional
|
||||
|
||||
from .invocation_queue_base import InvocationQueueABC
|
||||
from .invocation_queue_common import InvocationQueueItem
|
||||
|
||||
|
||||
class MemoryInvocationQueue(InvocationQueueABC):
|
||||
__queue: Queue
|
||||
__cancellations: dict[str, float]
|
||||
|
||||
def __init__(self):
|
||||
self.__queue = Queue()
|
||||
self.__cancellations = {}
|
||||
|
||||
def get(self) -> InvocationQueueItem:
|
||||
item = self.__queue.get()
|
||||
|
||||
while (
|
||||
isinstance(item, InvocationQueueItem)
|
||||
and item.graph_execution_state_id in self.__cancellations
|
||||
and self.__cancellations[item.graph_execution_state_id] > item.timestamp
|
||||
):
|
||||
item = self.__queue.get()
|
||||
|
||||
# Clear old items
|
||||
for graph_execution_state_id in list(self.__cancellations.keys()):
|
||||
if self.__cancellations[graph_execution_state_id] < item.timestamp:
|
||||
del self.__cancellations[graph_execution_state_id]
|
||||
|
||||
return item
|
||||
|
||||
def put(self, item: Optional[InvocationQueueItem]) -> None:
|
||||
self.__queue.put(item)
|
||||
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
if graph_execution_state_id not in self.__cancellations:
|
||||
self.__cancellations[graph_execution_state_id] = time.time()
|
||||
|
||||
def is_canceled(self, graph_execution_state_id: str) -> bool:
|
||||
return graph_execution_state_id in self.__cancellations
|
||||
@@ -16,7 +16,6 @@ if TYPE_CHECKING:
|
||||
from .board_images.board_images_base import BoardImagesServiceABC
|
||||
from .board_records.board_records_base import BoardRecordStorageBase
|
||||
from .boards.boards_base import BoardServiceABC
|
||||
from .bulk_download.bulk_download_base import BulkDownloadBase
|
||||
from .config import InvokeAIAppConfig
|
||||
from .download import DownloadQueueServiceBase
|
||||
from .events.events_base import EventServiceBase
|
||||
@@ -24,12 +23,15 @@ if TYPE_CHECKING:
|
||||
from .image_records.image_records_base import ImageRecordStorageBase
|
||||
from .images.images_base import ImageServiceABC
|
||||
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from .invocation_processor.invocation_processor_base import InvocationProcessorABC
|
||||
from .invocation_queue.invocation_queue_base import InvocationQueueABC
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .model_images.model_images_base import ModelImageFileStorageBase
|
||||
from .item_storage.item_storage_base import ItemStorageABC
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
from .shared.graph import GraphExecutionState
|
||||
from .urls.urls_base import UrlServiceBase
|
||||
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
|
||||
@@ -43,17 +45,18 @@ class InvocationServices:
|
||||
board_image_records: "BoardImageRecordStorageBase",
|
||||
boards: "BoardServiceABC",
|
||||
board_records: "BoardRecordStorageBase",
|
||||
bulk_download: "BulkDownloadBase",
|
||||
configuration: "InvokeAIAppConfig",
|
||||
events: "EventServiceBase",
|
||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
||||
images: "ImageServiceABC",
|
||||
image_files: "ImageFileStorageBase",
|
||||
image_records: "ImageRecordStorageBase",
|
||||
logger: "Logger",
|
||||
model_images: "ModelImageFileStorageBase",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
session_queue: "SessionQueueBase",
|
||||
session_processor: "SessionProcessorBase",
|
||||
invocation_cache: "InvocationCacheBase",
|
||||
@@ -67,17 +70,18 @@ class InvocationServices:
|
||||
self.board_image_records = board_image_records
|
||||
self.boards = boards
|
||||
self.board_records = board_records
|
||||
self.bulk_download = bulk_download
|
||||
self.configuration = configuration
|
||||
self.events = events
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.images = images
|
||||
self.image_files = image_files
|
||||
self.image_records = image_records
|
||||
self.logger = logger
|
||||
self.model_images = model_images
|
||||
self.model_manager = model_manager
|
||||
self.download_queue = download_queue
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
self.session_queue = session_queue
|
||||
self.session_processor = session_processor
|
||||
self.invocation_cache = invocation_cache
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
Usage:
|
||||
|
||||
statistics = InvocationStatsService()
|
||||
statistics = InvocationStatsService(graph_execution_manager)
|
||||
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||
... execute graphs...
|
||||
statistics.log_stats()
|
||||
@@ -30,7 +30,7 @@ writes to the system log is stored in InvocationServices.performance_statistics.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import ContextManager
|
||||
from typing import Iterator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
|
||||
@@ -50,7 +50,7 @@ class InvocationStatsServiceBase(ABC):
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
) -> ContextManager[None]:
|
||||
) -> Iterator[None]:
|
||||
"""
|
||||
Return a context object that will capture the statistics on the execution
|
||||
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
||||
@@ -60,8 +60,12 @@ class InvocationStatsServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_stats(self):
|
||||
"""Reset all stored statistics."""
|
||||
def reset_stats(self, graph_execution_state_id: str) -> None:
|
||||
"""
|
||||
Reset all statistics for the indicated graph.
|
||||
:param graph_execution_state_id: The id of the session whose stats to reset.
|
||||
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
from typing import Iterator
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
@@ -10,6 +10,7 @@ import torch
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
|
||||
from invokeai.backend.model_manager.load.model_cache import CacheStats
|
||||
|
||||
from .invocation_stats_base import InvocationStatsServiceBase
|
||||
@@ -41,7 +42,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self._invoker = invoker
|
||||
|
||||
@contextmanager
|
||||
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Generator[None, None, None]:
|
||||
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]:
|
||||
# This is to handle case of the model manager not being initialized, which happens
|
||||
# during some tests.
|
||||
services = self._invoker.services
|
||||
@@ -50,6 +51,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self._stats[graph_execution_state_id] = GraphExecutionStats()
|
||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||
|
||||
# Prune stale stats. There should be none since we're starting a new graph, but just in case.
|
||||
self._prune_stale_stats()
|
||||
|
||||
# Record state before the invocation.
|
||||
start_time = time.time()
|
||||
start_ram = psutil.Process().memory_info().rss
|
||||
@@ -74,9 +78,42 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
|
||||
|
||||
def reset_stats(self):
|
||||
self._stats = {}
|
||||
self._cache_stats = {}
|
||||
def _prune_stale_stats(self) -> None:
|
||||
"""Check all graphs being tracked and prune any that have completed/errored.
|
||||
|
||||
This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so
|
||||
for now we call this function periodically to prevent them from accumulating.
|
||||
"""
|
||||
to_prune: list[str] = []
|
||||
for graph_execution_state_id in self._stats:
|
||||
try:
|
||||
graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id)
|
||||
except ItemNotFoundError:
|
||||
# TODO(ryand): What would cause this? Should this exception just be allowed to propagate?
|
||||
logger.warning(f"Failed to get graph state for {graph_execution_state_id}.")
|
||||
continue
|
||||
|
||||
if not graph_execution_state.is_complete():
|
||||
# The graph is still running, don't prune it.
|
||||
continue
|
||||
|
||||
to_prune.append(graph_execution_state_id)
|
||||
|
||||
for graph_execution_state_id in to_prune:
|
||||
del self._stats[graph_execution_state_id]
|
||||
del self._cache_stats[graph_execution_state_id]
|
||||
|
||||
if len(to_prune) > 0:
|
||||
logger.info(f"Pruned stale graph stats for {to_prune}.")
|
||||
|
||||
def reset_stats(self, graph_execution_state_id: str):
|
||||
try:
|
||||
del self._stats[graph_execution_state_id]
|
||||
del self._cache_stats[graph_execution_state_id]
|
||||
except KeyError as e:
|
||||
raise GESStatsNotFoundError(
|
||||
f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}."
|
||||
) from e
|
||||
|
||||
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
|
||||
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
from .invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||
from .invocation_services import InvocationServices
|
||||
from .shared.graph import Graph, GraphExecutionState
|
||||
|
||||
|
||||
class Invoker:
|
||||
@@ -13,6 +18,51 @@ class Invoker:
|
||||
self.services = services
|
||||
self._start()
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
session_queue_id: str,
|
||||
session_queue_item_id: int,
|
||||
session_queue_batch_id: str,
|
||||
graph_execution_state: GraphExecutionState,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
invoke_all: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
|
||||
# Get the next invocation
|
||||
invocation = graph_execution_state.next()
|
||||
if not invocation:
|
||||
return None
|
||||
|
||||
# Save the execution state
|
||||
self.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Queue the invocation
|
||||
self.services.queue.put(
|
||||
InvocationQueueItem(
|
||||
session_queue_id=session_queue_id,
|
||||
session_queue_item_id=session_queue_item_id,
|
||||
session_queue_batch_id=session_queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_id=invocation.id,
|
||||
workflow=workflow,
|
||||
invoke_all=invoke_all,
|
||||
)
|
||||
)
|
||||
|
||||
return invocation.id
|
||||
|
||||
def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
|
||||
"""Creates a new execution state for the given graph"""
|
||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
||||
self.services.graph_execution_manager.set(new_state)
|
||||
return new_state
|
||||
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
"""Cancels the given execution state"""
|
||||
self.services.queue.cancel(graph_execution_state_id)
|
||||
|
||||
def __start_service(self, service) -> None:
|
||||
# Call start() method on any services that have it
|
||||
start_op = getattr(service, "start", None)
|
||||
@@ -35,3 +85,5 @@ class Invoker:
|
||||
# First stop all services
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
self.services.queue.put(None)
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
|
||||
class ModelImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, model_key: str) -> PILImageType:
|
||||
"""Retrieves a model image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, model_key: str) -> Path:
|
||||
"""Gets the internal path to a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, model_key: str) -> str | None:
|
||||
"""Gets the URL to fetch a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, image: PILImageType, model_key: str) -> None:
|
||||
"""Saves a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, model_key: str) -> None:
|
||||
"""Deletes a model image."""
|
||||
pass
|
||||
@@ -1,20 +0,0 @@
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ModelImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Model image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Model image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Model image file not deleted"):
|
||||
super().__init__(message)
|
||||
@@ -1,79 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.util.thumbnails import make_thumbnail
|
||||
|
||||
from .model_images_base import ModelImageFileStorageBase
|
||||
from .model_images_common import (
|
||||
ModelImageFileDeleteException,
|
||||
ModelImageFileNotFoundException,
|
||||
ModelImageFileSaveException,
|
||||
)
|
||||
|
||||
|
||||
class ModelImageFileStorageDisk(ModelImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
def __init__(self, model_images_folder: Path):
|
||||
self._model_images_folder = model_images_folder
|
||||
self._validate_storage_folders()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def get(self, model_key: str) -> PILImageType:
|
||||
try:
|
||||
path = self.get_path(model_key)
|
||||
|
||||
if not self._validate_path(path):
|
||||
raise ModelImageFileNotFoundException
|
||||
|
||||
return Image.open(path)
|
||||
except FileNotFoundError as e:
|
||||
raise ModelImageFileNotFoundException from e
|
||||
|
||||
def save(self, image: PILImageType, model_key: str) -> None:
|
||||
try:
|
||||
self._validate_storage_folders()
|
||||
image_path = self._model_images_folder / (model_key + ".webp")
|
||||
thumbnail = make_thumbnail(image, 256)
|
||||
thumbnail.save(image_path, format="webp")
|
||||
|
||||
except Exception as e:
|
||||
raise ModelImageFileSaveException from e
|
||||
|
||||
def get_path(self, model_key: str) -> Path:
|
||||
path = self._model_images_folder / (model_key + ".webp")
|
||||
|
||||
return path
|
||||
|
||||
def get_url(self, model_key: str) -> str | None:
|
||||
path = self.get_path(model_key)
|
||||
if not self._validate_path(path):
|
||||
return
|
||||
|
||||
return self._invoker.services.urls.get_model_image_url(model_key)
|
||||
|
||||
def delete(self, model_key: str) -> None:
|
||||
try:
|
||||
path = self.get_path(model_key)
|
||||
|
||||
if not self._validate_path(path):
|
||||
raise ModelImageFileNotFoundException
|
||||
|
||||
send2trash(path)
|
||||
|
||||
except Exception as e:
|
||||
raise ModelImageFileDeleteException from e
|
||||
|
||||
def _validate_path(self, path: Path) -> bool:
|
||||
"""Validates the path given for an image."""
|
||||
return path.exists()
|
||||
|
||||
def _validate_storage_folders(self) -> None:
|
||||
"""Checks if the required folders exist and create them if they don't"""
|
||||
self._model_images_folder.mkdir(parents=True, exist_ok=True)
|
||||
@@ -18,16 +18,16 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
@@ -150,20 +150,12 @@ ModelSource = Annotated[
|
||||
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
|
||||
]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
CivitaiModelSource: ModelSourceType.CivitAI,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
)
|
||||
@@ -185,12 +177,6 @@ class ModelInstallJob(BaseModel):
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
error_traceback: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the exception traceback"
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
@@ -198,10 +184,7 @@ class ModelInstallJob(BaseModel):
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self._exception = e
|
||||
self.error = str(e)
|
||||
self.error_traceback = self._format_error(e)
|
||||
self.status = InstallStatus.ERROR
|
||||
self.error_reason = self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
@@ -212,9 +195,10 @@ class ModelInstallJob(BaseModel):
|
||||
"""Class name of the exception that led to status==ERROR."""
|
||||
return self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def _format_error(self, exception: Exception) -> str:
|
||||
@property
|
||||
def error(self) -> Optional[str]:
|
||||
"""Error traceback."""
|
||||
return "".join(traceback.format_exception(exception))
|
||||
return "".join(traceback.format_exception(self._exception)) if self._exception else None
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
@@ -236,11 +220,6 @@ class ModelInstallJob(BaseModel):
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def downloads_done(self) -> bool:
|
||||
"""Return true if job's downloads ae done."""
|
||||
return self.status == InstallStatus.DOWNLOADS_DONE
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
@@ -266,6 +245,7 @@ class ModelInstallServiceBase(ABC):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: ModelMetadataStoreBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
@@ -352,7 +332,6 @@ class ModelInstallServiceBase(ABC):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
@@ -398,7 +377,7 @@ class ModelInstallServiceBase(ABC):
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record. Use it to override
|
||||
`name`, `description`, `base_type`, `model_type`, `format`,
|
||||
`prediction_type`, and/or `image_size`.
|
||||
`prediction_type`, `image_size`, and/or `ztsnr_training`.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
|
||||
@@ -4,10 +4,10 @@ import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from random import randbytes
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
@@ -21,17 +21,14 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
@@ -39,14 +36,12 @@ from invokeai.backend.model_manager.metadata import (
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
||||
|
||||
from .model_install_base import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
CivitaiModelSource,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
@@ -96,6 +91,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._next_job_id = 0
|
||||
self._metadata_store = record_store.metadata_store # for convenience
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
@@ -144,7 +140,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["source_type"] = ModelSourceType.Path
|
||||
return self._register(model_path, config)
|
||||
|
||||
def install_path(
|
||||
@@ -154,17 +149,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
|
||||
if self._app_config.skip_model_hash:
|
||||
config["hash"] = uuid_string()
|
||||
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
old_hash = info.current_hash
|
||||
dest_path = (
|
||||
self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
|
||||
self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name)
|
||||
)
|
||||
try:
|
||||
new_path = self._copy_model(model_path, dest_path)
|
||||
@@ -172,6 +163,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
raise DuplicateModelException(
|
||||
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
|
||||
) from excp
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||
|
||||
return self._register(
|
||||
new_path,
|
||||
@@ -184,14 +177,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
@@ -281,20 +273,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._scan_models_directory()
|
||||
if autoimport := self._app_config.autoimport_dir:
|
||||
self._logger.info("Scanning autoimport directory for new models")
|
||||
installed: List[str] = []
|
||||
# Use ThreadPoolExecutor to scan dirs in parallel
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType]
|
||||
[installed.extend(models.result()) for models in as_completed(future_models)]
|
||||
installed = self.scan_directory(self._app_config.root_path / autoimport)
|
||||
self._logger.info(f"{len(installed)} new models registered")
|
||||
self._logger.info("Model installer (re)initialized")
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
||||
if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) == 0:
|
||||
return []
|
||||
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback, config=self._app_config)
|
||||
search = ModelSearch(on_model_found=callback)
|
||||
self._models_installed.clear()
|
||||
search.search(scan_dir)
|
||||
return list(self._models_installed)
|
||||
@@ -380,24 +366,21 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._signal_job_errored(job)
|
||||
|
||||
elif (
|
||||
job.waiting or job.downloads_done
|
||||
job.waiting or job.downloading
|
||||
): # local jobs will be in waiting state, remote jobs will be downloading state
|
||||
job.total_bytes = self._stat_size(job.local_path)
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
# enter the metadata, if there is any
|
||||
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
|
||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
|
||||
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
|
||||
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
key = self.install_path(job.local_path, job.config_in)
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
|
||||
# enter the metadata, if there is any
|
||||
if job.source_metadata:
|
||||
self._metadata_store.add_metadata(key, job.source_metadata)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
except InvalidModelConfigException as excp:
|
||||
@@ -455,13 +438,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self.unregister(key)
|
||||
|
||||
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
|
||||
# Use ThreadPoolExecutor to scan dirs in parallel
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType]
|
||||
[installed.update(models.result()) for models in as_completed(future_models)]
|
||||
for cur_base_model in BaseModelType:
|
||||
for cur_model_type in ModelType:
|
||||
models_dir = Path(cur_base_model.value, cur_model_type.value)
|
||||
installed.update(self.scan_directory(models_dir))
|
||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||
|
||||
def _sync_model_path(self, key: str) -> AnyModelConfig:
|
||||
def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
|
||||
"""
|
||||
Move model into the location indicated by its basetype, type and name.
|
||||
|
||||
@@ -482,8 +465,15 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
new_path = models_dir / model.base.value / model.type.value / model.name
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
|
||||
if model.current_hash != new_hash:
|
||||
assert (
|
||||
ignore_hash_change
|
||||
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
|
||||
model.current_hash = new_hash
|
||||
self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
|
||||
self.record_store.update_model(key, model)
|
||||
return model
|
||||
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
@@ -535,15 +525,21 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
move(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
|
||||
if config: # used to override probe fields
|
||||
for key, value in config.items():
|
||||
setattr(info, key, value)
|
||||
return info
|
||||
|
||||
def _create_key(self) -> str:
|
||||
return sha256(randbytes(100)).hexdigest()[0:32]
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
config = config or {}
|
||||
|
||||
if self._app_config.skip_model_hash:
|
||||
config["hash"] = uuid_string()
|
||||
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
key = self._create_key()
|
||||
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self.app_config.models_path):
|
||||
@@ -552,12 +548,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info.path = model_path.as_posix()
|
||||
|
||||
# add 'main' specific fields
|
||||
if isinstance(info, CheckpointConfigBase):
|
||||
if hasattr(info, "config"):
|
||||
# make config relative to our root
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
|
||||
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(info)
|
||||
return info.key
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
|
||||
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(key, info)
|
||||
return key
|
||||
|
||||
def _next_id(self) -> int:
|
||||
with self._lock:
|
||||
@@ -577,15 +573,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
local_path=Path(source.path),
|
||||
inplace=source.inplace or False,
|
||||
inplace=source.inplace,
|
||||
)
|
||||
|
||||
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
if not source.access_token:
|
||||
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
|
||||
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
|
||||
str(source.version_id)
|
||||
)
|
||||
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
|
||||
@@ -613,17 +607,15 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
# URLs from Civitai or HuggingFace will be handled specially
|
||||
url_patterns = {
|
||||
r"^https?://civitai.com/": CivitaiMetadataFetch,
|
||||
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
|
||||
}
|
||||
metadata = None
|
||||
fetcher = None
|
||||
try:
|
||||
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||
except ValueError:
|
||||
pass
|
||||
kwargs: dict[str, Any] = {"session": self._session}
|
||||
if fetcher is CivitaiMetadataFetch:
|
||||
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
|
||||
if fetcher is not None:
|
||||
metadata = fetcher(**kwargs).from_url(source.url)
|
||||
for pattern, fetcher in url_patterns.items():
|
||||
if re.match(pattern, str(source.url), re.IGNORECASE):
|
||||
metadata = fetcher(self._session).from_url(source.url)
|
||||
break
|
||||
self._logger.debug(f"metadata={metadata}")
|
||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
@@ -638,7 +630,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _import_remote_model(
|
||||
self,
|
||||
source: HFModelSource | CivitaiModelSource | URLModelSource,
|
||||
source: ModelSource,
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
@@ -666,7 +658,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
if isinstance(source, HFModelSource) and source.subfolder:
|
||||
if hasattr(source, "subfolder") and source.subfolder:
|
||||
root = Path(remote_files[0].path.parts[0])
|
||||
subfolder = root / source.subfolder
|
||||
else:
|
||||
@@ -745,14 +737,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._signal_job_downloading(install_job)
|
||||
|
||||
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"{download_job.source}: model download complete")
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||
install_job.status = InstallStatus.DOWNLOADS_DONE
|
||||
if all(x.complete for x in install_job.download_parts):
|
||||
# now enqueue job for actual installation into the models directory
|
||||
self._install_queue.put(install_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
@@ -778,7 +769,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if not install_job:
|
||||
return
|
||||
self._downloads_changed_event.set()
|
||||
self._logger.warning(f"{download_job.source}: model download cancelled")
|
||||
self._logger.warning(f"Download {download_job.source} cancelled.")
|
||||
# if install job has already registered an error, then do not replace its status with cancelled
|
||||
if not install_job.errored:
|
||||
install_job.cancel()
|
||||
@@ -825,7 +816,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
id=job.id,
|
||||
)
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||
@@ -838,7 +828,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
key = job.config_out.key
|
||||
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
|
||||
self._event_bus.emit_model_install_completed(str(job.source), key)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob) -> None:
|
||||
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}\n{job.error}")
|
||||
@@ -847,17 +837,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
error = job.error
|
||||
assert error_type is not None
|
||||
assert error is not None
|
||||
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
|
||||
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
|
||||
|
||||
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
|
||||
self._logger.info(f"{job.source}: model installation was cancelled")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source))
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str):
|
||||
if re.match(r"^https?://civitai.com/", url.lower()):
|
||||
return CivitaiMetadataFetch
|
||||
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||
return HuggingFaceMetadataFetch
|
||||
raise ValueError(f"Unsupported model source: '{url}'")
|
||||
|
||||
@@ -38,3 +38,8 @@ class ModelLoadServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def gpu_count(self) -> int:
|
||||
"""Return the number of GPUs we are configured to use."""
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
@@ -39,6 +40,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self._registry = registry
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
"""Start the service."""
|
||||
self._invoker = invoker
|
||||
|
||||
@property
|
||||
@@ -46,6 +48,11 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Return the RAM cache used by this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
@property
|
||||
def gpu_count(self) -> int:
|
||||
"""Return the number of GPUs available for our uses."""
|
||||
return len(self._ram_cache.execution_devices)
|
||||
|
||||
@property
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
@@ -94,20 +101,22 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
) -> None:
|
||||
if not self._invoker:
|
||||
return
|
||||
if self._invoker.services.queue.is_canceled(context_data.session_id):
|
||||
raise CanceledException()
|
||||
|
||||
if not loaded:
|
||||
self._invoker.services.events.emit_model_load_started(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
queue_id=context_data.queue_id,
|
||||
queue_item_id=context_data.queue_item_id,
|
||||
queue_batch_id=context_data.batch_id,
|
||||
graph_execution_state_id=context_data.session_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
else:
|
||||
self._invoker.services.events.emit_model_load_completed(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
queue_id=context_data.queue_id,
|
||||
queue_item_id=context_data.queue_item_id,
|
||||
queue_batch_id=context_data.batch_id,
|
||||
graph_execution_state_id=context_data.session_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
@@ -13,6 +16,7 @@ from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
@@ -28,10 +32,9 @@ class ModelManagerServiceBase(ABC):
|
||||
def build_model_manager(
|
||||
cls,
|
||||
app_config: InvokeAIAppConfig,
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
db: SqliteDatabase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
execution_device: torch.device,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
@@ -66,3 +69,37 @@ class ModelManagerServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def gpu_count(self) -> int:
|
||||
"""Return the number of GPUs we are configured to use."""
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of ModelManagerServiceBase."""
|
||||
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
@@ -14,7 +16,7 @@ from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase
|
||||
from ..model_records import ModelRecordServiceBase, UnknownModelException
|
||||
from .model_manager_base import ModelManagerServiceBase
|
||||
|
||||
|
||||
@@ -60,6 +62,61 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
if hasattr(service, "stop"):
|
||||
service.stop(invoker)
|
||||
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
return self.load.load_model(model_config, submodel_type, context_data)
|
||||
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
config = self.store.get_model(key)
|
||||
return self.load.load_model(config, submodel_type, context_data)
|
||||
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||
|
||||
This is provided for API compatability with the get_model() method
|
||||
in the original model manager. However, note that LoadedModel is
|
||||
not the same as the original ModelInfo that ws returned.
|
||||
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
:param context: The invocation context.
|
||||
|
||||
Exceptions: UnknownModelException -- model with this key not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
ValueError -- more than one model matches this combination
|
||||
"""
|
||||
configs = self.store.search_by_attr(model_name, base_model, model_type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
|
||||
elif len(configs) > 1:
|
||||
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
||||
else:
|
||||
return self.load.load_model(configs[0], submodel, context_data)
|
||||
|
||||
@property
|
||||
def gpu_count(self) -> int:
|
||||
"""Return the number of GPUs we are using."""
|
||||
return self.load.gpu_count
|
||||
|
||||
@classmethod
|
||||
def build_model_manager(
|
||||
cls,
|
||||
@@ -67,7 +124,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
execution_device: torch.device = choose_torch_device(),
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
@@ -78,10 +134,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
logger.setLevel(app_config.log_level.upper())
|
||||
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=app_config.ram_cache_size,
|
||||
max_vram_cache_size=app_config.vram_cache_size,
|
||||
logger=logger,
|
||||
execution_device=execution_device,
|
||||
max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger
|
||||
)
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
|
||||
|
||||
9
invokeai/app/services/model_metadata/__init__.py
Normal file
9
invokeai/app/services/model_metadata/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Init file for ModelMetadataStoreService module."""
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
from .metadata_store_sql import ModelMetadataStoreSQL
|
||||
|
||||
__all__ = [
|
||||
"ModelMetadataStoreBase",
|
||||
"ModelMetadataStoreSQL",
|
||||
]
|
||||
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
@abstractmethod
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Init file for model record services."""
|
||||
|
||||
from .model_records_base import ( # noqa F401
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
|
||||
@@ -6,19 +6,20 @@ Abstract base class for storing and retrieving model configuration records.
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
@@ -59,34 +60,11 @@ class ModelSummary(BaseModel):
|
||||
tags: Set[str] = Field(description="tags associated with model")
|
||||
|
||||
|
||||
class ModelRecordChanges(BaseModelExcludeNull):
|
||||
"""A set of changes to apply to a model."""
|
||||
|
||||
# Changes applicable to all models
|
||||
name: Optional[str] = Field(description="Name of the model.", default=None)
|
||||
path: Optional[str] = Field(description="Path to the model.", default=None)
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
# Checkpoint-specific changes
|
||||
# TODO(MM2): Should we expose these? Feels footgun-y...
|
||||
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
|
||||
prediction_type: Optional[SchedulerPredictionType] = Field(
|
||||
description="The prediction type of the model.", default=None
|
||||
)
|
||||
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
|
||||
config_path: Optional[str] = Field(description="Path to config file for model", default=None)
|
||||
|
||||
|
||||
class ModelRecordServiceBase(ABC):
|
||||
"""Abstract base class for storage and retrieval of model configs."""
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@@ -110,12 +88,13 @@ class ModelRecordServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated.
|
||||
:param changes: A set of changes to apply to this model. Changes are validated before being written.
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -130,15 +109,38 @@ class ModelRecordServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
"""
|
||||
Retrieve the configuration for the indicated model.
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
pass
|
||||
|
||||
:param hash: Hash of model config to be fetched.
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
@abstractmethod
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -215,3 +217,21 @@ class ModelRecordServiceBase(ABC):
|
||||
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
||||
)
|
||||
return model_configs[0]
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Rename the indicated model. Just a special case of update_model().
|
||||
|
||||
In some implementations, renaming the model may involve changing where
|
||||
it is stored on the filesystem. So this is broken out.
|
||||
|
||||
:param key: Model key
|
||||
:param new_name: New name for model
|
||||
"""
|
||||
config = self.get_model(key)
|
||||
config.name = new_name
|
||||
return self.update_model(key, config)
|
||||
|
||||
@@ -39,11 +39,12 @@ Typical usage:
|
||||
configs = store.search_by_attr(base_model='sd-2', model_type='main')
|
||||
"""
|
||||
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@@ -53,11 +54,12 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_records_base import (
|
||||
DuplicateModelException,
|
||||
ModelRecordChanges,
|
||||
ModelRecordOrderBy,
|
||||
ModelRecordServiceBase,
|
||||
ModelSummary,
|
||||
@@ -68,7 +70,7 @@ from .model_records_base import (
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
@@ -77,13 +79,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = db.conn.cursor()
|
||||
self._metadata_store = metadata_store
|
||||
|
||||
@property
|
||||
def db(self) -> SqliteDatabase:
|
||||
"""Return the underlying database."""
|
||||
return self._db
|
||||
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@@ -93,19 +96,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
INSERT INTO model_config (
|
||||
id,
|
||||
original_hash,
|
||||
config
|
||||
)
|
||||
VALUES (?,?);
|
||||
VALUES (?,?,?);
|
||||
""",
|
||||
(
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
key,
|
||||
record.original_hash,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
@@ -113,12 +120,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
if "model_config.path" in str(e):
|
||||
msg = f"A model with path '{record.path}' is already installed"
|
||||
elif "model_config.name" in str(e):
|
||||
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
msg = f"A model with key '{key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
raise e
|
||||
@@ -126,7 +133,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(config.key)
|
||||
return self.get_model(key)
|
||||
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
@@ -140,7 +147,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM models
|
||||
DELETE FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@@ -152,20 +159,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
record = self.get_model(key)
|
||||
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
|
||||
json_serialized = record.model_dump_json()
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
UPDATE model_config
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
@@ -192,7 +200,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@@ -203,21 +211,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Return True if a model with the indicated key exists in the databse.
|
||||
@@ -228,7 +221,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
select count(*) FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@@ -254,8 +247,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
where_clause: list[str] = []
|
||||
bindings: list[str] = []
|
||||
results = []
|
||||
where_clause = []
|
||||
bindings = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
@@ -272,13 +266,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
select config, strftime('%s',updated_at) FROM model_config
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
@@ -287,7 +282,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
@@ -298,13 +293,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated hash."""
|
||||
"""Return models with the indicated original_hash."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE original_hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
@@ -313,35 +308,83 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
]
|
||||
return results
|
||||
|
||||
@property
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
return self._metadata_store
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
store = self.metadata_store
|
||||
try:
|
||||
metadata = store.get_metadata(key)
|
||||
return metadata
|
||||
except UnknownMetadataException:
|
||||
return None
|
||||
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
keys = store.search_by_tag(tags)
|
||||
return [self.get_model(x) for x in keys]
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_tags()
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_all_metadata()
|
||||
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, format, name",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
|
||||
ModelRecordOrderBy.Type: "a.type",
|
||||
ModelRecordOrderBy.Base: "a.base",
|
||||
ModelRecordOrderBy.Name: "a.name",
|
||||
ModelRecordOrderBy.Format: "a.format",
|
||||
}
|
||||
|
||||
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
|
||||
"""Fix up results so that there are no null values."""
|
||||
result: Dict[str, Union[str, int, Set[str]]] = {}
|
||||
for key, item in summary.items():
|
||||
result[key] = item or ""
|
||||
result["tags"] = set(json.loads(summary["tags"] or "[]"))
|
||||
return result
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
with self._db.lock:
|
||||
# query1: get the total number of model configs
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from models;
|
||||
select count(*) from model_config;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(self._cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields
|
||||
# query2: fetch key fields from the join of model_config and model_metadata
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config
|
||||
FROM models
|
||||
SELECT a.id as key, a.type, a.base, a.format, a.name,
|
||||
json_extract(a.config, '$.description') as description,
|
||||
json_extract(b.metadata, '$.tags') as tags
|
||||
FROM model_config AS a
|
||||
LEFT JOIN model_metadata AS b on a.id=b.id
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
@@ -352,7 +395,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
|
||||
return PaginatedResults(
|
||||
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
|
||||
)
|
||||
|
||||
@@ -4,17 +4,3 @@ from pydantic import BaseModel, Field
|
||||
class SessionProcessorStatus(BaseModel):
|
||||
is_started: bool = Field(description="Whether the session processor is started")
|
||||
is_processing: bool = Field(description="Whether a session is being processed")
|
||||
|
||||
|
||||
class CanceledException(Exception):
|
||||
"""Execution canceled by user."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional
|
||||
@@ -7,271 +6,136 @@ from typing import Optional
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .session_processor_base import SessionProcessorBase
|
||||
from .session_processor_common import SessionProcessorStatus
|
||||
|
||||
POLLING_INTERVAL = 1
|
||||
THREAD_LIMIT = 1
|
||||
|
||||
|
||||
class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
self._queue_item: Optional[SessionQueueItem] = None
|
||||
self._invocation: Optional[BaseInvocation] = None
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker: Invoker = invoker
|
||||
self.__queue_item: Optional[SessionQueueItem] = None
|
||||
|
||||
self._resume_event = ThreadEvent()
|
||||
self._stop_event = ThreadEvent()
|
||||
self._poll_now_event = ThreadEvent()
|
||||
self._cancel_event = ThreadEvent()
|
||||
self.__resume_event = ThreadEvent()
|
||||
self.__stop_event = ThreadEvent()
|
||||
self.__poll_now_event = ThreadEvent()
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||
|
||||
self._thread_limit = thread_limit
|
||||
self._thread_semaphore = BoundedSemaphore(thread_limit)
|
||||
self._polling_interval = polling_interval
|
||||
|
||||
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
|
||||
# the profiler will create a new profile for each session.
|
||||
self._profiler = (
|
||||
Profiler(
|
||||
logger=self._invoker.services.logger,
|
||||
output_dir=self._invoker.services.configuration.profiles_path,
|
||||
prefix=self._invoker.services.configuration.profile_prefix,
|
||||
)
|
||||
if self._invoker.services.configuration.profile_graphs
|
||||
else None
|
||||
)
|
||||
|
||||
self._thread = Thread(
|
||||
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
|
||||
self.__thread = Thread(
|
||||
name="session_processor",
|
||||
target=self._process,
|
||||
target=self.__process,
|
||||
kwargs={
|
||||
"stop_event": self._stop_event,
|
||||
"poll_now_event": self._poll_now_event,
|
||||
"resume_event": self._resume_event,
|
||||
"cancel_event": self._cancel_event,
|
||||
"stop_event": self.__stop_event,
|
||||
"poll_now_event": self.__poll_now_event,
|
||||
"resume_event": self.__resume_event,
|
||||
},
|
||||
)
|
||||
self._thread.start()
|
||||
self.__thread.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self._stop_event.set()
|
||||
self.__stop_event.set()
|
||||
|
||||
def _poll_now(self) -> None:
|
||||
self._poll_now_event.set()
|
||||
self.__poll_now_event.set()
|
||||
|
||||
async def _on_queue_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
if event_name == "session_canceled" or event_name == "queue_cleared":
|
||||
# These both mean we should cancel the current session.
|
||||
self._cancel_event.set()
|
||||
# This was a match statement, but match is not supported on python 3.9
|
||||
if event_name in [
|
||||
"graph_execution_state_complete",
|
||||
"invocation_error",
|
||||
"session_retrieval_error",
|
||||
"invocation_retrieval_error",
|
||||
]:
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
elif (
|
||||
event_name == "session_canceled"
|
||||
and self.__queue_item is not None
|
||||
and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"]
|
||||
):
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
elif event_name == "batch_enqueued":
|
||||
self._poll_now()
|
||||
elif event_name == "queue_cleared":
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
if not self._resume_event.is_set():
|
||||
self._resume_event.set()
|
||||
if not self.__resume_event.is_set():
|
||||
self.__resume_event.set()
|
||||
return self.get_status()
|
||||
|
||||
def pause(self) -> SessionProcessorStatus:
|
||||
if self._resume_event.is_set():
|
||||
self._resume_event.clear()
|
||||
if self.__resume_event.is_set():
|
||||
self.__resume_event.clear()
|
||||
return self.get_status()
|
||||
|
||||
def get_status(self) -> SessionProcessorStatus:
|
||||
return SessionProcessorStatus(
|
||||
is_started=self._resume_event.is_set(),
|
||||
is_processing=self._queue_item is not None,
|
||||
is_started=self.__resume_event.is_set(),
|
||||
is_processing=self.__queue_item is not None,
|
||||
)
|
||||
|
||||
def _process(
|
||||
def __process(
|
||||
self,
|
||||
stop_event: ThreadEvent,
|
||||
poll_now_event: ThreadEvent,
|
||||
resume_event: ThreadEvent,
|
||||
cancel_event: ThreadEvent,
|
||||
):
|
||||
# Outermost processor try block; any unhandled exception is a fatal processor error
|
||||
try:
|
||||
self._thread_semaphore.acquire()
|
||||
stop_event.clear()
|
||||
resume_event.set()
|
||||
cancel_event.clear()
|
||||
|
||||
self.__threadLimit.acquire()
|
||||
queue_item: Optional[SessionQueueItem] = None
|
||||
while not stop_event.is_set():
|
||||
poll_now_event.clear()
|
||||
# Middle processor try block; any unhandled exception is a non-fatal processor error
|
||||
try:
|
||||
# Get the next session to process
|
||||
self._queue_item = self._invoker.services.session_queue.dequeue()
|
||||
if self._queue_item is not None and resume_event.is_set():
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
# do not dequeue if there is already a session running
|
||||
if self.__queue_item is None and resume_event.is_set():
|
||||
queue_item = self.__invoker.services.session_queue.dequeue()
|
||||
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||
|
||||
# Prepare invocations and take the first
|
||||
self._invocation = self._queue_item.session.next()
|
||||
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while self._invocation is not None and not cancel_event.is_set():
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||
|
||||
# Send starting event
|
||||
self._invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session_id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
if queue_item is not None:
|
||||
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
|
||||
self.__queue_item = queue_item
|
||||
self.__invoker.services.graph_execution_manager.set(queue_item.session)
|
||||
self.__invoker.invoke(
|
||||
session_queue_batch_id=queue_item.batch_id,
|
||||
session_queue_id=queue_item.queue_id,
|
||||
session_queue_item_id=queue_item.item_id,
|
||||
graph_execution_state=queue_item.session,
|
||||
workflow=queue_item.workflow,
|
||||
invoke_all=True,
|
||||
)
|
||||
queue_item = None
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
try:
|
||||
with self._invoker.services.performance_statistics.collect_stats(
|
||||
self._invocation, self._queue_item.session.id
|
||||
):
|
||||
# Build invocation context (the node-facing API)
|
||||
data = InvocationContextData(
|
||||
invocation=self._invocation,
|
||||
source_invocation_id=source_invocation_id,
|
||||
queue_item=self._queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._invoker.services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
|
||||
# Invoke the node
|
||||
outputs = self._invocation.invoke_internal(
|
||||
context=context, services=self._invoker.services
|
||||
)
|
||||
|
||||
# Save outputs and history
|
||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
self._queue_item.session.set_node_error(self._invocation.id, error)
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=self._queue_item.session_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
pass
|
||||
|
||||
# The session is complete if the all invocations are complete or there was an error
|
||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
|
||||
# Set the invocation to None to prepare for the next session
|
||||
self._invocation = None
|
||||
else:
|
||||
# Prepare the next invocation
|
||||
self._invocation = self._queue_item.session.next()
|
||||
|
||||
# The session is complete, immediately poll for next session
|
||||
self._queue_item = None
|
||||
poll_now_event.set()
|
||||
else:
|
||||
# The queue was empty, wait for next polling interval or event to try again
|
||||
self._invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
if queue_item is None:
|
||||
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(POLLING_INTERVAL)
|
||||
continue
|
||||
except Exception:
|
||||
# Non-fatal error in processor
|
||||
self._invoker.services.logger.error(
|
||||
f"Non-fatal error in session processor:\n{traceback.format_exc()}"
|
||||
)
|
||||
# Cancel the queue item
|
||||
if self._queue_item is not None:
|
||||
self._invoker.services.session_queue.cancel_queue_item(
|
||||
self._queue_item.item_id, error=traceback.format_exc()
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error(f"Error in session processor: {e}")
|
||||
if queue_item is not None:
|
||||
self.__invoker.services.session_queue.cancel_queue_item(
|
||||
queue_item.item_id, error=traceback.format_exc()
|
||||
)
|
||||
# Reset the invocation to None to prepare for the next session
|
||||
self._invocation = None
|
||||
# Immediately poll for next queue item
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
poll_now_event.wait(POLLING_INTERVAL)
|
||||
continue
|
||||
except Exception:
|
||||
# Fatal error in processor, log and pass - we're done here
|
||||
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error(f"Fatal Error in session processor: {e}")
|
||||
pass
|
||||
finally:
|
||||
stop_event.clear()
|
||||
poll_now_event.clear()
|
||||
self._queue_item = None
|
||||
self._thread_semaphore.release()
|
||||
self.__queue_item = None
|
||||
self.__threadLimit.release()
|
||||
|
||||
@@ -60,7 +60,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
# This was a match statement, but match is not supported on python 3.9
|
||||
if event_name == "graph_execution_state_complete":
|
||||
await self._handle_complete_event(event)
|
||||
elif event_name == "invocation_error":
|
||||
elif event_name in ["invocation_error", "session_retrieval_error", "invocation_retrieval_error"]:
|
||||
await self._handle_error_event(event)
|
||||
elif event_name == "session_canceled":
|
||||
await self._handle_cancel_event(event)
|
||||
@@ -429,6 +429,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
status = "failed" if error is not None else "canceled"
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here
|
||||
self.__invoker.services.queue.cancel(queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
@@ -470,6 +471,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self.__invoker.services.queue.cancel(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
@@ -521,6 +523,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self.__invoker.services.queue.cancel(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
|
||||
92
invokeai/app/services/shared/default_graphs.py
Normal file
92
invokeai/app/services/shared/default_graphs.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
|
||||
|
||||
from ...invocations.compel import CompelInvocation
|
||||
from ...invocations.image import ImageNSFWBlurInvocation
|
||||
from ...invocations.latent import DenoiseLatentsInvocation, LatentsToImageInvocation
|
||||
from ...invocations.noise import NoiseInvocation
|
||||
from ...invocations.primitives import IntegerInvocation
|
||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||
|
||||
default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
|
||||
|
||||
|
||||
def create_text_to_image() -> LibraryGraph:
|
||||
graph = Graph(
|
||||
nodes={
|
||||
"width": IntegerInvocation(id="width", value=512),
|
||||
"height": IntegerInvocation(id="height", value=512),
|
||||
"seed": IntegerInvocation(id="seed", value=-1),
|
||||
"3": NoiseInvocation(id="3"),
|
||||
"4": CompelInvocation(id="4"),
|
||||
"5": CompelInvocation(id="5"),
|
||||
"6": DenoiseLatentsInvocation(id="6"),
|
||||
"7": LatentsToImageInvocation(id="7"),
|
||||
"8": ImageNSFWBlurInvocation(id="8"),
|
||||
},
|
||||
edges=[
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="width", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="width"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="height", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="height"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="seed", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="seed"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="3", field="noise"),
|
||||
destination=EdgeConnection(node_id="6", field="noise"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="6", field="latents"),
|
||||
destination=EdgeConnection(node_id="7", field="latents"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="4", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="5", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="7", field="image"),
|
||||
destination=EdgeConnection(node_id="8", field="image"),
|
||||
),
|
||||
],
|
||||
)
|
||||
return LibraryGraph(
|
||||
id=default_text_to_image_graph_id,
|
||||
name="t2i",
|
||||
description="Converts text to an image",
|
||||
graph=graph,
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
||||
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
||||
ExposedNodeInput(node_path="width", field="value", alias="width"),
|
||||
ExposedNodeInput(node_path="height", field="value", alias="height"),
|
||||
ExposedNodeInput(node_path="seed", field="value", alias="seed"),
|
||||
],
|
||||
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
|
||||
)
|
||||
|
||||
|
||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||
graphs: list[LibraryGraph] = []
|
||||
|
||||
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
# TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
# if text_to_image is None:
|
||||
text_to_image = create_text_to_image()
|
||||
graph_library.set(text_to_image)
|
||||
|
||||
graphs.append(text_to_image)
|
||||
|
||||
return graphs
|
||||
@@ -5,14 +5,8 @@ import itertools
|
||||
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
GetJsonSchemaHandler,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
||||
from pydantic.fields import Field
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from pydantic_core import CoreSchema
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from invokeai.app.invocations import * # noqa: F401 F403
|
||||
@@ -182,6 +176,10 @@ class NodeIdMismatchError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidSubGraphError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class CyclicalGraphError(ValueError):
|
||||
pass
|
||||
|
||||
@@ -190,6 +188,25 @@ class UnknownGraphValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
# TODO: Create and use an Empty output?
|
||||
@invocation_output("graph_output")
|
||||
class GraphInvocationOutput(BaseInvocationOutput):
|
||||
pass
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
@invocation("graph", version="1.0.0")
|
||||
class GraphInvocation(BaseInvocation):
|
||||
"""Execute a graph"""
|
||||
|
||||
# TODO: figure out how to create a default here
|
||||
graph: "Graph" = InputField(description="The graph to run", default=None)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
|
||||
"""Invoke with provided services and return outputs."""
|
||||
return GraphInvocationOutput()
|
||||
|
||||
|
||||
@invocation_output("iterate_output")
|
||||
class IterateInvocationOutput(BaseInvocationOutput):
|
||||
"""Used to connect iteration outputs. Will be expanded to a specific output."""
|
||||
@@ -243,73 +260,21 @@ class CollectInvocation(BaseInvocation):
|
||||
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
||||
|
||||
|
||||
InvocationsUnion: Any = BaseInvocation.get_invocations_union()
|
||||
InvocationOutputsUnion: Any = BaseInvocationOutput.get_outputs_union()
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
)
|
||||
edges: list[Edge] = Field(
|
||||
description="The connections between nodes and their fields in this graph",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
@field_validator("nodes", mode="plain")
|
||||
@classmethod
|
||||
def validate_nodes(cls, v: dict[str, Any]):
|
||||
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
|
||||
|
||||
# Invocations register themselves as their python modules are executed. The union of all invocations is
|
||||
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
|
||||
#
|
||||
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
|
||||
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
|
||||
# invocations will cause a graph to fail if they are used.
|
||||
#
|
||||
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
|
||||
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
|
||||
#
|
||||
# This same pattern is used in `GraphExecutionState`.
|
||||
|
||||
nodes: dict[str, BaseInvocation] = {}
|
||||
typeadapter = BaseInvocation.get_typeadapter()
|
||||
for node_id, node in v.items():
|
||||
nodes[node_id] = typeadapter.validate_python(node)
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
|
||||
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
|
||||
# the generated schema as options for the `nodes` field.
|
||||
#
|
||||
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
|
||||
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
|
||||
# expected.
|
||||
#
|
||||
# You might be tempted to do something like this:
|
||||
#
|
||||
# ```py
|
||||
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
|
||||
# delattr(cloned_model, "validate_nodes")
|
||||
# cloned_model.model_rebuild(force=True)
|
||||
# json_schema = handler(cloned_model.__pydantic_core_schema__)
|
||||
# ```
|
||||
#
|
||||
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
|
||||
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
|
||||
#
|
||||
# This same pattern is used in `GraphExecutionState`.
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: Optional[str] = Field(default=None, description="The id of this graph")
|
||||
nodes: dict[
|
||||
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
|
||||
] = Field(description="The nodes in this graph")
|
||||
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
|
||||
|
||||
json_schema = handler(Graph.__pydantic_core_schema__)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
return json_schema
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
"""Adds a node to a graph
|
||||
|
||||
@@ -321,21 +286,41 @@ class Graph(BaseModel):
|
||||
|
||||
self.nodes[node.id] = node
|
||||
|
||||
def delete_node(self, node_id: str) -> None:
|
||||
def _get_graph_and_node(self, node_path: str) -> tuple["Graph", str]:
|
||||
"""Returns the graph and node id for a node path."""
|
||||
# Materialized graphs may have nodes at the top level
|
||||
if node_path in self.nodes:
|
||||
return (self, node_path)
|
||||
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
if node_id not in self.nodes:
|
||||
raise NodeNotFoundError(f"Node {node_path} not found in graph")
|
||||
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if not isinstance(node, GraphInvocation):
|
||||
# There's more node path left but this isn't a graph - failure
|
||||
raise NodeNotFoundError("Node path terminated early at a non-graph node")
|
||||
|
||||
return node.graph._get_graph_and_node(node_path[node_path.index(".") + 1 :])
|
||||
|
||||
def delete_node(self, node_path: str) -> None:
|
||||
"""Deletes a node from a graph"""
|
||||
|
||||
try:
|
||||
graph, node_id = self._get_graph_and_node(node_path)
|
||||
|
||||
# Delete edges for this node
|
||||
input_edges = self._get_input_edges(node_id)
|
||||
output_edges = self._get_output_edges(node_id)
|
||||
input_edges = self._get_input_edges_and_graphs(node_path)
|
||||
output_edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
for edge in input_edges:
|
||||
self.delete_edge(edge)
|
||||
for edge_graph, _, edge in input_edges:
|
||||
edge_graph.delete_edge(edge)
|
||||
|
||||
for edge in output_edges:
|
||||
self.delete_edge(edge)
|
||||
for edge_graph, _, edge in output_edges:
|
||||
edge_graph.delete_edge(edge)
|
||||
|
||||
del self.nodes[node_id]
|
||||
del graph.nodes[node_id]
|
||||
|
||||
except NodeNotFoundError:
|
||||
pass # Ignore, not doesn't exist (should this throw?)
|
||||
@@ -385,6 +370,13 @@ class Graph(BaseModel):
|
||||
if k != v.id:
|
||||
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
|
||||
|
||||
# Validate all subgraphs
|
||||
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
|
||||
try:
|
||||
gn.graph.validate_self()
|
||||
except Exception as e:
|
||||
raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e
|
||||
|
||||
# Validate that all edges match nodes and fields in the graph
|
||||
for edge in self.edges:
|
||||
source_node = self.nodes.get(edge.source.node_id, None)
|
||||
@@ -446,6 +438,7 @@ class Graph(BaseModel):
|
||||
except (
|
||||
DuplicateNodeIdError,
|
||||
NodeIdMismatchError,
|
||||
InvalidSubGraphError,
|
||||
NodeNotFoundError,
|
||||
NodeFieldNotFoundError,
|
||||
CyclicalGraphError,
|
||||
@@ -466,7 +459,7 @@ class Graph(BaseModel):
|
||||
def _validate_edge(self, edge: Edge):
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
try:
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
@@ -533,90 +526,171 @@ class Graph(BaseModel):
|
||||
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
def has_node(self, node_id: str) -> bool:
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
"""Determines whether or not a node exists in the graph."""
|
||||
try:
|
||||
_ = self.get_node(node_id)
|
||||
return True
|
||||
n = self.get_node(node_path)
|
||||
if n is not None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
|
||||
def get_node(self, node_id: str) -> BaseInvocation:
|
||||
"""Gets a node from the graph."""
|
||||
try:
|
||||
return self.nodes[node_id]
|
||||
except KeyError as e:
|
||||
raise NodeNotFoundError(f"Node {node_id} not found in graph") from e
|
||||
def get_node(self, node_path: str) -> BaseInvocation:
|
||||
"""Gets a node from the graph using a node path."""
|
||||
# Materialized graphs may have nodes at the top level
|
||||
graph, node_id = self._get_graph_and_node(node_path)
|
||||
return graph.nodes[node_id]
|
||||
|
||||
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
|
||||
def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
|
||||
return node_id if prefix is None or prefix == "" else f"{prefix}.{node_id}"
|
||||
|
||||
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
|
||||
"""Updates a node in the graph."""
|
||||
node = self.nodes[node_id]
|
||||
graph, node_id = self._get_graph_and_node(node_path)
|
||||
node = graph.nodes[node_id]
|
||||
|
||||
# Ensure the node type matches the new node
|
||||
if type(node) is not type(new_node):
|
||||
raise TypeError(f"Node {node_id} is type {type(node)} but new node is type {type(new_node)}")
|
||||
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
|
||||
|
||||
# Ensure the new id is either the same or is not in the graph
|
||||
if new_node.id != node.id and self.has_node(new_node.id):
|
||||
raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph")
|
||||
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
|
||||
new_path = self._get_node_path(new_node.id, prefix=prefix)
|
||||
if new_node.id != node.id and self.has_node(new_path):
|
||||
raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
|
||||
|
||||
# Set the new node in the graph
|
||||
self.nodes[new_node.id] = new_node
|
||||
graph.nodes[new_node.id] = new_node
|
||||
if new_node.id != node.id:
|
||||
input_edges = self._get_input_edges(node_id)
|
||||
output_edges = self._get_output_edges(node_id)
|
||||
input_edges = self._get_input_edges_and_graphs(node_path)
|
||||
output_edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
# Delete node and all edges
|
||||
self.delete_node(node_id)
|
||||
graph.delete_node(node_path)
|
||||
|
||||
# Create new edges for each input and output
|
||||
for edge in input_edges:
|
||||
self.add_edge(
|
||||
for graph, _, edge in input_edges:
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = (
|
||||
new_node.id
|
||||
if "." not in edge.destination.node_id
|
||||
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
|
||||
)
|
||||
graph.add_edge(
|
||||
Edge(
|
||||
source=edge.source,
|
||||
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
|
||||
destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field),
|
||||
)
|
||||
)
|
||||
|
||||
for edge in output_edges:
|
||||
self.add_edge(
|
||||
for graph, _, edge in output_edges:
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = (
|
||||
new_node.id
|
||||
if "." not in edge.source.node_id
|
||||
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
|
||||
)
|
||||
graph.add_edge(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=new_node.id, field=edge.source.field),
|
||||
source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field),
|
||||
destination=edge.destination,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_input_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
|
||||
"""Gets all input edges for a node. If field is provided, only edges to that field are returned."""
|
||||
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
|
||||
"""Gets all input edges for a node"""
|
||||
edges = self._get_input_edges_and_graphs(node_path)
|
||||
|
||||
edges = [e for e in self.edges if e.destination.node_id == node_id]
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
|
||||
|
||||
if field is None:
|
||||
return edges
|
||||
# Create full node paths for each edge
|
||||
return [
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||
field=e.source.field,
|
||||
),
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
),
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
|
||||
filtered_edges = [e for e in edges if e.destination.field == field]
|
||||
def _get_input_edges_and_graphs(
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = []
|
||||
|
||||
return filtered_edges
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
|
||||
|
||||
def _get_output_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
|
||||
"""Gets all output edges for a node. If field is provided, only edges from that field are returned."""
|
||||
edges = [e for e in self.edges if e.source.node_id == node_id]
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if field is None:
|
||||
return edges
|
||||
if isinstance(node, GraphInvocation):
|
||||
graph = node.graph
|
||||
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
||||
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
||||
edges.extend(graph_edges)
|
||||
|
||||
filtered_edges = [e for e in edges if e.source.field == field]
|
||||
return edges
|
||||
|
||||
return filtered_edges
|
||||
def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
|
||||
"""Gets all output edges for a node"""
|
||||
edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if e[2].source.field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||
field=e.source.field,
|
||||
),
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
),
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
|
||||
def _get_output_edges_and_graphs(
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = []
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
|
||||
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if isinstance(node, GraphInvocation):
|
||||
graph = node.graph
|
||||
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
||||
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
||||
edges.extend(graph_edges)
|
||||
|
||||
return edges
|
||||
|
||||
def _is_iterator_connection_valid(
|
||||
self,
|
||||
node_id: str,
|
||||
node_path: str,
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_id, "item")]
|
||||
inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@@ -644,12 +718,12 @@ class Graph(BaseModel):
|
||||
|
||||
def _is_collector_connection_valid(
|
||||
self,
|
||||
node_id: str,
|
||||
node_path: str,
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = [e.source for e in self._get_input_edges(node_id, "item")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]
|
||||
inputs = [e.source for e in self._get_input_edges(node_path, "item")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@@ -705,17 +779,27 @@ class Graph(BaseModel):
|
||||
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
|
||||
return g
|
||||
|
||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
|
||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
|
||||
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
|
||||
g = nx_graph or nx.DiGraph()
|
||||
|
||||
# Add all nodes from this graph except graph/iteration nodes
|
||||
g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)])
|
||||
g.add_nodes_from(
|
||||
[
|
||||
self._get_node_path(n.id, prefix)
|
||||
for n in self.nodes.values()
|
||||
if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
|
||||
]
|
||||
)
|
||||
|
||||
# Expand graph nodes
|
||||
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
|
||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
|
||||
g.add_edges_from([(e[0], e[1]) for e in unique_edges])
|
||||
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
|
||||
return g
|
||||
|
||||
|
||||
@@ -740,7 +824,9 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
|
||||
# The results of executed nodes
|
||||
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The results of node executions", default_factory=dict
|
||||
)
|
||||
|
||||
# Errors raised when executing nodes
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||
@@ -757,51 +843,27 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
@field_validator("results", mode="plain")
|
||||
@classmethod
|
||||
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
|
||||
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
|
||||
|
||||
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
|
||||
results: dict[str, BaseInvocationOutput] = {}
|
||||
typeadapter = BaseInvocationOutput.get_typeadapter()
|
||||
for result_id, result in v.items():
|
||||
results[result_id] = typeadapter.validate_python(result)
|
||||
return results
|
||||
|
||||
@field_validator("graph")
|
||||
def graph_is_valid(cls, v: Graph):
|
||||
"""Validates that the graph is valid"""
|
||||
v.validate_self()
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(description="The id of the execution state")
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
|
||||
executed: set[str] = Field(description="The set of node ids that have been executed")
|
||||
executed_history: list[str] = Field(
|
||||
description="The list of node ids that have been executed, in order of execution"
|
||||
)
|
||||
results: dict[
|
||||
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
|
||||
] = Field(description="The results of node executions")
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
|
||||
prepared_source_mapping: dict[str, str] = Field(
|
||||
description="The map of prepared nodes to original graph nodes"
|
||||
)
|
||||
source_prepared_mapping: dict[str, set[str]] = Field(
|
||||
description="The map of original graph nodes to prepared nodes"
|
||||
)
|
||||
|
||||
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
return json_schema
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"required": [
|
||||
"id",
|
||||
"graph",
|
||||
"execution_graph",
|
||||
"executed",
|
||||
"executed_history",
|
||||
"results",
|
||||
"errors",
|
||||
"prepared_source_mapping",
|
||||
"source_prepared_mapping",
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
def next(self) -> Optional[BaseInvocation]:
|
||||
"""Gets the next node ready to execute."""
|
||||
@@ -857,17 +919,17 @@ class GraphExecutionState(BaseModel):
|
||||
"""Returns true if the graph has any errors"""
|
||||
return len(self.errors) > 0
|
||||
|
||||
def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
||||
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
||||
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
||||
|
||||
node = self.graph.get_node(node_id)
|
||||
node = self.graph.get_node(node_path)
|
||||
|
||||
self_iteration_count = -1
|
||||
|
||||
# If this is an iterator node, we must create a copy for each iteration
|
||||
if isinstance(node, IterateInvocation):
|
||||
# Get input collection edge (should error if there are no inputs)
|
||||
input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection")))
|
||||
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
|
||||
input_collection_prepared_node_id = next(
|
||||
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
|
||||
)
|
||||
@@ -881,7 +943,7 @@ class GraphExecutionState(BaseModel):
|
||||
return new_nodes
|
||||
|
||||
# Get all input edges
|
||||
input_edges = self.graph._get_input_edges(node_id)
|
||||
input_edges = self.graph._get_input_edges(node_path)
|
||||
|
||||
# Create new edges for this iteration
|
||||
# For collect nodes, this may contain multiple inputs to the same field
|
||||
@@ -908,10 +970,10 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Add to execution graph
|
||||
self.execution_graph.add_node(new_node)
|
||||
self.prepared_source_mapping[new_node.id] = node_id
|
||||
if node_id not in self.source_prepared_mapping:
|
||||
self.source_prepared_mapping[node_id] = set()
|
||||
self.source_prepared_mapping[node_id].add(new_node.id)
|
||||
self.prepared_source_mapping[new_node.id] = node_path
|
||||
if node_path not in self.source_prepared_mapping:
|
||||
self.source_prepared_mapping[node_path] = set()
|
||||
self.source_prepared_mapping[node_path].add(new_node.id)
|
||||
|
||||
# Add new edges to execution graph
|
||||
for edge in new_edges:
|
||||
@@ -1015,13 +1077,13 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def _get_iteration_node(
|
||||
self,
|
||||
source_node_id: str,
|
||||
source_node_path: str,
|
||||
graph: nx.DiGraph,
|
||||
execution_graph: nx.DiGraph,
|
||||
prepared_iterator_nodes: list[str],
|
||||
) -> Optional[str]:
|
||||
"""Gets the prepared version of the specified source node that matches every iteration specified"""
|
||||
prepared_nodes = self.source_prepared_mapping[source_node_id]
|
||||
prepared_nodes = self.source_prepared_mapping[source_node_path]
|
||||
if len(prepared_nodes) == 1:
|
||||
return next(iter(prepared_nodes))
|
||||
|
||||
@@ -1032,7 +1094,7 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
|
||||
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
|
||||
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]
|
||||
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
|
||||
|
||||
return next(
|
||||
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
|
||||
@@ -1101,19 +1163,19 @@ class GraphExecutionState(BaseModel):
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
self.graph.add_node(node)
|
||||
|
||||
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
|
||||
if not self._is_node_updatable(node_id):
|
||||
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
|
||||
if not self._is_node_updatable(node_path):
|
||||
raise NodeAlreadyExecutedError(
|
||||
f"Node {node_id} has already been prepared or executed and cannot be updated"
|
||||
f"Node {node_path} has already been prepared or executed and cannot be updated"
|
||||
)
|
||||
self.graph.update_node(node_id, new_node)
|
||||
self.graph.update_node(node_path, new_node)
|
||||
|
||||
def delete_node(self, node_id: str) -> None:
|
||||
if not self._is_node_updatable(node_id):
|
||||
def delete_node(self, node_path: str) -> None:
|
||||
if not self._is_node_updatable(node_path):
|
||||
raise NodeAlreadyExecutedError(
|
||||
f"Node {node_id} has already been prepared or executed and cannot be deleted"
|
||||
f"Node {node_path} has already been prepared or executed and cannot be deleted"
|
||||
)
|
||||
self.graph.delete_node(node_id)
|
||||
self.graph.delete_node(node_path)
|
||||
|
||||
def add_edge(self, edge: Edge) -> None:
|
||||
if not self._is_node_updatable(edge.destination.node_id):
|
||||
@@ -1128,3 +1190,63 @@ class GraphExecutionState(BaseModel):
|
||||
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
|
||||
)
|
||||
self.graph.delete_edge(edge)
|
||||
|
||||
|
||||
class ExposedNodeInput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the input")
|
||||
field: str = Field(description="The field name of the input")
|
||||
alias: str = Field(description="The alias of the input")
|
||||
|
||||
|
||||
class ExposedNodeOutput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the output")
|
||||
field: str = Field(description="The field name of the output")
|
||||
alias: str = Field(description="The alias of the output")
|
||||
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid_string)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
description: str = Field(description="The description of the graph")
|
||||
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(
|
||||
description="The outputs exposed by this graph", default_factory=list
|
||||
)
|
||||
|
||||
@field_validator("exposed_inputs", "exposed_outputs")
|
||||
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
|
||||
if len(v) != len({i.alias for i in v}):
|
||||
raise ValueError("Duplicate exposed alias")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_exposed_nodes(cls, values):
|
||||
graph = values.graph
|
||||
|
||||
# Validate exposed inputs
|
||||
for exposed_input in values.exposed_inputs:
|
||||
if not graph.has_node(exposed_input.node_path):
|
||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||
node = graph.get_node(exposed_input.node_path)
|
||||
if get_input_field(node, exposed_input.field) is None:
|
||||
raise ValueError(
|
||||
f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}"
|
||||
)
|
||||
|
||||
# Validate exposed outputs
|
||||
for exposed_output in values.exposed_outputs:
|
||||
if not graph.has_node(exposed_output.node_path):
|
||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||
node = graph.get_node(exposed_output.node_path)
|
||||
if get_output_field(node, exposed_output.field) is None:
|
||||
raise ValueError(
|
||||
f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
GraphInvocation.model_rebuild(force=True)
|
||||
Graph.model_rebuild(force=True)
|
||||
GraphExecutionState.model_rebuild(force=True)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
@@ -13,17 +12,16 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.model import ModelField
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
|
||||
"""
|
||||
The InvocationContext provides access to various services and data about the current invocation.
|
||||
@@ -50,102 +48,99 @@ Note: The docstrings are in weird places, but that's where they must be to get I
|
||||
|
||||
@dataclass
|
||||
class InvocationContextData:
|
||||
queue_item: "SessionQueueItem"
|
||||
"""The queue item that is being executed."""
|
||||
invocation: "BaseInvocation"
|
||||
"""The invocation that is being executed."""
|
||||
source_invocation_id: str
|
||||
"""The ID of the invocation from which the currently executing invocation was prepared."""
|
||||
session_id: str
|
||||
"""The session that is being executed."""
|
||||
queue_id: str
|
||||
"""The queue in which the session is being executed."""
|
||||
source_node_id: str
|
||||
"""The ID of the node from which the currently executing invocation was prepared."""
|
||||
queue_item_id: int
|
||||
"""The ID of the queue item that is being executed."""
|
||||
batch_id: str
|
||||
"""The ID of the batch that is being executed."""
|
||||
workflow: Optional[WorkflowWithoutID] = None
|
||||
"""The workflow associated with this queue item, if any."""
|
||||
|
||||
|
||||
class InvocationContextInterface:
|
||||
def __init__(self, services: InvocationServices, data: InvocationContextData) -> None:
|
||||
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
|
||||
self._services = services
|
||||
self._data = data
|
||||
self._context_data = context_data
|
||||
|
||||
|
||||
class BoardsInterface(InvocationContextInterface):
|
||||
def create(self, board_name: str) -> BoardDTO:
|
||||
"""Creates a board.
|
||||
"""
|
||||
Creates a board.
|
||||
|
||||
Args:
|
||||
board_name: The name of the board to create.
|
||||
|
||||
Returns:
|
||||
The created board DTO.
|
||||
:param board_name: The name of the board to create.
|
||||
"""
|
||||
return self._services.boards.create(board_name)
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
"""Gets a board DTO.
|
||||
"""
|
||||
Gets a board DTO.
|
||||
|
||||
Args:
|
||||
board_id: The ID of the board to get.
|
||||
|
||||
Returns:
|
||||
The board DTO.
|
||||
:param board_id: The ID of the board to get.
|
||||
"""
|
||||
return self._services.boards.get_dto(board_id)
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
"""Gets all boards.
|
||||
|
||||
Returns:
|
||||
A list of all boards.
|
||||
"""
|
||||
Gets all boards.
|
||||
"""
|
||||
return self._services.boards.get_all()
|
||||
|
||||
def add_image_to_board(self, board_id: str, image_name: str) -> None:
|
||||
"""Adds an image to a board.
|
||||
"""
|
||||
Adds an image to a board.
|
||||
|
||||
Args:
|
||||
board_id: The ID of the board to add the image to.
|
||||
image_name: The name of the image to add to the board.
|
||||
:param board_id: The ID of the board to add the image to.
|
||||
:param image_name: The name of the image to add to the board.
|
||||
"""
|
||||
return self._services.board_images.add_image_to_board(board_id, image_name)
|
||||
|
||||
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
|
||||
"""Gets all image names for a board.
|
||||
"""
|
||||
Gets all image names for a board.
|
||||
|
||||
Args:
|
||||
board_id: The ID of the board to get the image names for.
|
||||
|
||||
Returns:
|
||||
A list of all image names for the board.
|
||||
:param board_id: The ID of the board to get the image names for.
|
||||
"""
|
||||
return self._services.board_images.get_all_board_image_names_for_board(board_id)
|
||||
|
||||
|
||||
class LoggerInterface(InvocationContextInterface):
|
||||
def debug(self, message: str) -> None:
|
||||
"""Logs a debug message.
|
||||
"""
|
||||
Logs a debug message.
|
||||
|
||||
Args:
|
||||
message: The message to log.
|
||||
:param message: The message to log.
|
||||
"""
|
||||
self._services.logger.debug(message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""Logs an info message.
|
||||
"""
|
||||
Logs an info message.
|
||||
|
||||
Args:
|
||||
message: The message to log.
|
||||
:param message: The message to log.
|
||||
"""
|
||||
self._services.logger.info(message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""Logs a warning message.
|
||||
"""
|
||||
Logs a warning message.
|
||||
|
||||
Args:
|
||||
message: The message to log.
|
||||
:param message: The message to log.
|
||||
"""
|
||||
self._services.logger.warning(message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""Logs an error message.
|
||||
"""
|
||||
Logs an error message.
|
||||
|
||||
Args:
|
||||
message: The message to log.
|
||||
:param message: The message to log.
|
||||
"""
|
||||
self._services.logger.error(message)
|
||||
|
||||
@@ -158,60 +153,54 @@ class ImagesInterface(InvocationContextInterface):
|
||||
image_category: ImageCategory = ImageCategory.GENERAL,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
) -> ImageDTO:
|
||||
"""Saves an image, returning its DTO.
|
||||
"""
|
||||
Saves an image, returning its DTO.
|
||||
|
||||
If the current queue item has a workflow or metadata, it is automatically saved with the image.
|
||||
|
||||
Args:
|
||||
image: The image to save, as a PIL image.
|
||||
board_id: The board ID to add the image to, if it should be added. It the invocation \
|
||||
:param image: The image to save, as a PIL image.
|
||||
:param board_id: The board ID to add the image to, if it should be added. It the invocation \
|
||||
inherits from `WithBoard`, that board will be used automatically. **Use this only if \
|
||||
you want to override or provide a board manually!**
|
||||
image_category: The category of the image. Only the GENERAL category is added \
|
||||
:param image_category: The category of the image. Only the GENERAL category is added \
|
||||
to the gallery.
|
||||
metadata: The metadata to save with the image, if it should have any. If the \
|
||||
:param metadata: The metadata to save with the image, if it should have any. If the \
|
||||
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
|
||||
**Use this only if you want to override or provide metadata manually!**
|
||||
|
||||
Returns:
|
||||
The saved image DTO.
|
||||
"""
|
||||
|
||||
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
|
||||
metadata_ = None
|
||||
if metadata:
|
||||
metadata_ = metadata
|
||||
elif isinstance(self._data.invocation, WithMetadata):
|
||||
metadata_ = self._data.invocation.metadata
|
||||
elif isinstance(self._context_data.invocation, WithMetadata):
|
||||
metadata_ = self._context_data.invocation.metadata
|
||||
|
||||
# If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None.
|
||||
board_id_ = None
|
||||
if board_id:
|
||||
board_id_ = board_id
|
||||
elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board:
|
||||
board_id_ = self._data.invocation.board.board_id
|
||||
elif isinstance(self._context_data.invocation, WithBoard) and self._context_data.invocation.board:
|
||||
board_id_ = self._context_data.invocation.board.board_id
|
||||
|
||||
return self._services.images.create(
|
||||
image=image,
|
||||
is_intermediate=self._data.invocation.is_intermediate,
|
||||
is_intermediate=self._context_data.invocation.is_intermediate,
|
||||
image_category=image_category,
|
||||
board_id=board_id_,
|
||||
metadata=metadata_,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
workflow=self._data.queue_item.workflow,
|
||||
session_id=self._data.queue_item.session_id,
|
||||
node_id=self._data.invocation.id,
|
||||
workflow=self._context_data.workflow,
|
||||
session_id=self._context_data.session_id,
|
||||
node_id=self._context_data.invocation.id,
|
||||
)
|
||||
|
||||
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
|
||||
"""Gets an image as a PIL Image object.
|
||||
"""
|
||||
Gets an image as a PIL Image object.
|
||||
|
||||
Args:
|
||||
image_name: The name of the image to get.
|
||||
mode: The color mode to convert the image to. If None, the original mode is used.
|
||||
|
||||
Returns:
|
||||
The image as a PIL Image object.
|
||||
:param image_name: The name of the image to get.
|
||||
:param mode: The color mode to convert the image to. If None, the original mode is used.
|
||||
"""
|
||||
image = self._services.images.get_pil_image(image_name)
|
||||
if mode and mode != image.mode:
|
||||
@@ -224,221 +213,163 @@ class ImagesInterface(InvocationContextInterface):
|
||||
return image
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
"""Gets an image's metadata, if it has any.
|
||||
"""
|
||||
Gets an image's metadata, if it has any.
|
||||
|
||||
Args:
|
||||
image_name: The name of the image to get the metadata for.
|
||||
|
||||
Returns:
|
||||
The image's metadata, if it has any.
|
||||
:param image_name: The name of the image to get the metadata for.
|
||||
"""
|
||||
return self._services.images.get_metadata(image_name)
|
||||
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
"""Gets an image as an ImageDTO object.
|
||||
"""
|
||||
Gets an image as an ImageDTO object.
|
||||
|
||||
Args:
|
||||
image_name: The name of the image to get.
|
||||
|
||||
Returns:
|
||||
The image as an ImageDTO object.
|
||||
:param image_name: The name of the image to get.
|
||||
"""
|
||||
return self._services.images.get_dto(image_name)
|
||||
|
||||
|
||||
class TensorsInterface(InvocationContextInterface):
|
||||
def save(self, tensor: Tensor) -> str:
|
||||
"""Saves a tensor, returning its name.
|
||||
"""
|
||||
Saves a tensor, returning its name.
|
||||
|
||||
Args:
|
||||
tensor: The tensor to save.
|
||||
|
||||
Returns:
|
||||
The name of the saved tensor.
|
||||
:param tensor: The tensor to save.
|
||||
"""
|
||||
|
||||
name = self._services.tensors.save(obj=tensor)
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> Tensor:
|
||||
"""Loads a tensor by name.
|
||||
"""
|
||||
Loads a tensor by name.
|
||||
|
||||
Args:
|
||||
name: The name of the tensor to load.
|
||||
|
||||
Returns:
|
||||
The loaded tensor.
|
||||
:param name: The name of the tensor to load.
|
||||
"""
|
||||
return self._services.tensors.load(name)
|
||||
|
||||
|
||||
class ConditioningInterface(InvocationContextInterface):
|
||||
def save(self, conditioning_data: ConditioningFieldData) -> str:
|
||||
"""Saves a conditioning data object, returning its name.
|
||||
"""
|
||||
Saves a conditioning data object, returning its name.
|
||||
|
||||
Args:
|
||||
conditioning_data: The conditioning data to save.
|
||||
|
||||
Returns:
|
||||
The name of the saved conditioning data.
|
||||
:param conditioning_context_data: The conditioning data to save.
|
||||
"""
|
||||
|
||||
name = self._services.conditioning.save(obj=conditioning_data)
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> ConditioningFieldData:
|
||||
"""Loads conditioning data by name.
|
||||
"""
|
||||
Loads conditioning data by name.
|
||||
|
||||
Args:
|
||||
name: The name of the conditioning data to load.
|
||||
|
||||
Returns:
|
||||
The loaded conditioning data.
|
||||
:param name: The name of the conditioning data to load.
|
||||
"""
|
||||
|
||||
return self._services.conditioning.load(name)
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
def exists(self, identifier: Union[str, "ModelField"]) -> bool:
|
||||
"""Checks if a model exists.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
|
||||
Returns:
|
||||
True if the model exists, False if not.
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.exists(identifier)
|
||||
Checks if a model exists.
|
||||
|
||||
return self._services.model_manager.store.exists(identifier.key)
|
||||
:param key: The key of the model.
|
||||
"""
|
||||
return self._services.model_manager.store.exists(key)
|
||||
|
||||
def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""Loads a model.
|
||||
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Loads a model.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
submodel_type: The submodel of the model to get.
|
||||
|
||||
Returns:
|
||||
An object representing the loaded model.
|
||||
:param key: The key of the model.
|
||||
:param submodel_type: The submodel of the model to get.
|
||||
:returns: An object representing the loaded model.
|
||||
"""
|
||||
|
||||
# The model manager emits events as it loads the model. It needs the context data to build
|
||||
# the event payloads.
|
||||
|
||||
if isinstance(identifier, str):
|
||||
model = self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
|
||||
else:
|
||||
_submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
|
||||
return self._services.model_manager.load_model_by_key(
|
||||
key=key, submodel_type=submodel_type, context_data=self._context_data
|
||||
)
|
||||
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""Loads a model by its attributes.
|
||||
|
||||
Args:
|
||||
name: Name of the model.
|
||||
base: The models' base type, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
|
||||
type: Type of the model, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
|
||||
submodel_type: The type of submodel to load, e.g. `SubModelType.UNet`, `SubModelType.TextEncoder`, etc. Only main
|
||||
models have submodels.
|
||||
|
||||
Returns:
|
||||
An object representing the loaded model.
|
||||
"""
|
||||
Loads a model by its attributes.
|
||||
|
||||
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
|
||||
"""Gets a model's config.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
|
||||
Returns:
|
||||
The model's config.
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
"""
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.load_model_by_attr(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
context_data=self._context_data,
|
||||
)
|
||||
|
||||
return self._services.model_manager.store.get_model(identifier.key)
|
||||
def get_config(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Gets a model's info, an dict-like object.
|
||||
|
||||
:param key: The key of the model.
|
||||
"""
|
||||
return self._services.model_manager.store.get_model(key=key)
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Gets a model's metadata, if it has any.
|
||||
|
||||
:param key: The key of the model.
|
||||
"""
|
||||
return self._services.model_manager.store.get_metadata(key=key)
|
||||
|
||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||
"""Searches for models by path.
|
||||
"""
|
||||
Searches for models by path.
|
||||
|
||||
Args:
|
||||
path: The path to search for.
|
||||
|
||||
Returns:
|
||||
A list of models that match the path.
|
||||
:param path: The path to search for.
|
||||
"""
|
||||
return self._services.model_manager.store.search_by_path(path)
|
||||
|
||||
def search_by_attrs(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
base: Optional[BaseModelType] = None,
|
||||
type: Optional[ModelType] = None,
|
||||
format: Optional[ModelFormat] = None,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_format: Optional[ModelFormat] = None,
|
||||
) -> list[AnyModelConfig]:
|
||||
"""Searches for models by attributes.
|
||||
"""
|
||||
Searches for models by attributes.
|
||||
|
||||
Args:
|
||||
name: The name to search for (exact match).
|
||||
base: The base to search for, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
|
||||
type: Type type of model to search for, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
|
||||
format: The format of model to search for, e.g. `ModelFormat.Checkpoint`, `ModelFormat.Diffusers`, etc.
|
||||
|
||||
Returns:
|
||||
A list of models that match the attributes.
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
"""
|
||||
|
||||
return self._services.model_manager.store.search_by_attr(
|
||||
model_name=name,
|
||||
base_model=base,
|
||||
model_type=type,
|
||||
model_format=format,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_format=model_format,
|
||||
)
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
def get(self) -> InvokeAIAppConfig:
|
||||
"""Gets the app's config.
|
||||
|
||||
Returns:
|
||||
The app's config.
|
||||
"""
|
||||
"""Gets the app's config."""
|
||||
|
||||
return self._services.configuration.get_config()
|
||||
|
||||
|
||||
class UtilInterface(InvocationContextInterface):
|
||||
def __init__(
|
||||
self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event
|
||||
) -> None:
|
||||
super().__init__(services, data)
|
||||
self._cancel_event = cancel_event
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
"""Checks if the current session has been canceled.
|
||||
|
||||
Returns:
|
||||
True if the current session has been canceled, False if not.
|
||||
"""
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
||||
"""
|
||||
The step callback emits a progress event with the current step, the total number of
|
||||
@@ -446,32 +377,27 @@ class UtilInterface(InvocationContextInterface):
|
||||
|
||||
This should be called after each denoising step.
|
||||
|
||||
Args:
|
||||
intermediate_state: The intermediate state of the diffusion pipeline.
|
||||
base_model: The base model for the current denoising step.
|
||||
:param intermediate_state: The intermediate state of the diffusion pipeline.
|
||||
:param base_model: The base model for the current denoising step.
|
||||
"""
|
||||
|
||||
# The step callback needs access to the events and the invocation queue services, but this
|
||||
# represents a dangerous level of access.
|
||||
#
|
||||
# We wrap the step callback so that nodes do not have direct access to these services.
|
||||
|
||||
stable_diffusion_step_callback(
|
||||
context_data=self._data,
|
||||
context_data=self._context_data,
|
||||
intermediate_state=intermediate_state,
|
||||
base_model=base_model,
|
||||
invocation_queue=self._services.queue,
|
||||
events=self._services.events,
|
||||
is_canceled=self.is_canceled,
|
||||
)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Provides access to various services and data for the current invocation.
|
||||
|
||||
Attributes:
|
||||
images (ImagesInterface): Methods to save, get and update images and their metadata.
|
||||
tensors (TensorsInterface): Methods to save and get tensors, including image, noise, masks, and masked images.
|
||||
conditioning (ConditioningInterface): Methods to save and get conditioning data.
|
||||
models (ModelsInterface): Methods to check if a model exists, get a model, and get a model's info.
|
||||
logger (LoggerInterface): The app logger.
|
||||
config (ConfigInterface): The app config.
|
||||
util (UtilInterface): Utility methods, including a method to check if an invocation was canceled and step callbacks.
|
||||
boards (BoardsInterface): Methods to interact with boards.
|
||||
"""
|
||||
The `InvocationContext` provides access to various services and data for the current invocation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -484,54 +410,50 @@ class InvocationContext:
|
||||
config: ConfigInterface,
|
||||
util: UtilInterface,
|
||||
boards: BoardsInterface,
|
||||
data: InvocationContextData,
|
||||
context_data: InvocationContextData,
|
||||
services: InvocationServices,
|
||||
) -> None:
|
||||
self.images = images
|
||||
"""Methods to save, get and update images and their metadata."""
|
||||
"""Provides methods to save, get and update images and their metadata."""
|
||||
self.tensors = tensors
|
||||
"""Methods to save and get tensors, including image, noise, masks, and masked images."""
|
||||
"""Provides methods to save and get tensors, including image, noise, masks, and masked images."""
|
||||
self.conditioning = conditioning
|
||||
"""Methods to save and get conditioning data."""
|
||||
"""Provides methods to save and get conditioning data."""
|
||||
self.models = models
|
||||
"""Methods to check if a model exists, get a model, and get a model's info."""
|
||||
"""Provides methods to check if a model exists, get a model, and get a model's info."""
|
||||
self.logger = logger
|
||||
"""The app logger."""
|
||||
"""Provides access to the app logger."""
|
||||
self.config = config
|
||||
"""The app config."""
|
||||
"""Provides access to the app's config."""
|
||||
self.util = util
|
||||
"""Utility methods, including a method to check if an invocation was canceled and step callbacks."""
|
||||
"""Provides utility methods."""
|
||||
self.boards = boards
|
||||
"""Methods to interact with boards."""
|
||||
self._data = data
|
||||
"""An internal API providing access to data about the current queue item and invocation. You probably shouldn't use this. It may change without warning."""
|
||||
"""Provides methods to interact with boards."""
|
||||
self._data = context_data
|
||||
"""Provides data about the current queue item and invocation. This is an internal API and may change without warning."""
|
||||
self._services = services
|
||||
"""An internal API providing access to all application services. You probably shouldn't use this. It may change without warning."""
|
||||
"""Provides access to the full application services. This is an internal API and may change without warning."""
|
||||
|
||||
|
||||
def build_invocation_context(
|
||||
services: InvocationServices,
|
||||
data: InvocationContextData,
|
||||
cancel_event: threading.Event,
|
||||
context_data: InvocationContextData,
|
||||
) -> InvocationContext:
|
||||
"""Builds the invocation context for a specific invocation execution.
|
||||
"""
|
||||
Builds the invocation context for a specific invocation execution.
|
||||
|
||||
Args:
|
||||
services: The invocation services to wrap.
|
||||
data: The invocation context data.
|
||||
|
||||
Returns:
|
||||
The invocation context.
|
||||
:param invocation_services: The invocation services to wrap.
|
||||
:param invocation_context_data: The invocation context data.
|
||||
"""
|
||||
|
||||
logger = LoggerInterface(services=services, data=data)
|
||||
images = ImagesInterface(services=services, data=data)
|
||||
tensors = TensorsInterface(services=services, data=data)
|
||||
models = ModelsInterface(services=services, data=data)
|
||||
config = ConfigInterface(services=services, data=data)
|
||||
util = UtilInterface(services=services, data=data, cancel_event=cancel_event)
|
||||
conditioning = ConditioningInterface(services=services, data=data)
|
||||
boards = BoardsInterface(services=services, data=data)
|
||||
logger = LoggerInterface(services=services, context_data=context_data)
|
||||
images = ImagesInterface(services=services, context_data=context_data)
|
||||
tensors = TensorsInterface(services=services, context_data=context_data)
|
||||
models = ModelsInterface(services=services, context_data=context_data)
|
||||
config = ConfigInterface(services=services, context_data=context_data)
|
||||
util = UtilInterface(services=services, context_data=context_data)
|
||||
conditioning = ConditioningInterface(services=services, context_data=context_data)
|
||||
boards = BoardsInterface(services=services, context_data=context_data)
|
||||
|
||||
ctx = InvocationContext(
|
||||
images=images,
|
||||
@@ -539,7 +461,7 @@ def build_invocation_context(
|
||||
config=config,
|
||||
tensors=tensors,
|
||||
models=models,
|
||||
data=data,
|
||||
context_data=context_data,
|
||||
util=util,
|
||||
conditioning=conditioning,
|
||||
services=services,
|
||||
|
||||
@@ -9,7 +9,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -36,7 +35,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_4())
|
||||
migrator.register_migration(build_migration_5())
|
||||
migrator.register_migration(build_migration_6())
|
||||
migrator.register_migration(build_migration_7())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration7Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._create_models_table(cursor)
|
||||
self._drop_old_models_tables(cursor)
|
||||
|
||||
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
|
||||
|
||||
tables = ["model_records", "model_metadata", "model_tags", "tags"]
|
||||
|
||||
for table in tables:
|
||||
cursor.execute(f"DROP TABLE IF EXISTS {table};")
|
||||
|
||||
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the v4.0.0 models table."""
|
||||
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS models (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
|
||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
|
||||
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
|
||||
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
|
||||
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
|
||||
trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
|
||||
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS models_updated_at
|
||||
AFTER UPDATE
|
||||
ON models FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
# Add indexes for searchable fields
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS base_index ON models(base);",
|
||||
"CREATE INDEX IF NOT EXISTS type_index ON models(type);",
|
||||
"CREATE INDEX IF NOT EXISTS name_index ON models(name);",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
|
||||
def build_migration_7() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 6 to 7.
|
||||
|
||||
This migration does the following:
|
||||
- Adds the new models table
|
||||
- Drops the old model_records, model_metadata, model_tags and tags tables.
|
||||
- TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
|
||||
"""
|
||||
migration_7 = Migration(
|
||||
from_version=6,
|
||||
to_version=7,
|
||||
callback=Migration7Callback(),
|
||||
)
|
||||
|
||||
return migration_7
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from hashlib import sha1
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@@ -21,7 +22,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelConfigFactory,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import ModelHash
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
@@ -72,27 +73,19 @@ class MigrateModelYamlToDb1:
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
try:
|
||||
hash = ModelHash().hash(self.config.models_path / stanza.path)
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
except OSError:
|
||||
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||
continue
|
||||
|
||||
assert isinstance(model_key, str)
|
||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||
|
||||
stanza["base"] = BaseModelType(base_type)
|
||||
stanza["type"] = ModelType(model_type)
|
||||
stanza["name"] = model_name
|
||||
stanza["original_hash"] = hash
|
||||
stanza["current_hash"] = hash
|
||||
new_key = hash # deterministic key assignment
|
||||
|
||||
# special case for ip adapters, which need the new `image_encoder_model_id` field
|
||||
if stanza["type"] == ModelType.IPAdapter:
|
||||
try:
|
||||
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
|
||||
self.config.models_path / stanza.path
|
||||
)
|
||||
except OSError:
|
||||
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
|
||||
continue
|
||||
|
||||
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||
|
||||
@@ -102,7 +95,7 @@ class MigrateModelYamlToDb1:
|
||||
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||
self._update_model(key, new_config)
|
||||
else:
|
||||
self.logger.info(f"Adding model {model_name} with key {new_key}")
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
self._add_model(new_key, new_config)
|
||||
except DuplicateModelException:
|
||||
self.logger.warning(f"Model {model_name} is already in the database")
|
||||
@@ -150,14 +143,9 @@ class MigrateModelYamlToDb1:
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.hash,
|
||||
record.original_hash,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
||||
|
||||
def _get_image_encoder_model_id(self, model_path: Path) -> str:
|
||||
with open(model_path / "image_encoder.txt") as f:
|
||||
encoder = f.read()
|
||||
return encoder.strip()
|
||||
|
||||
@@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
|
||||
See :class:`Migration` for an example.
|
||||
"""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
...
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
|
||||
@@ -8,8 +8,3 @@ class UrlServiceBase(ABC):
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
"""Gets the URL for a model image"""
|
||||
pass
|
||||
|
||||
@@ -4,9 +4,8 @@ from .urls_base import UrlServiceBase
|
||||
|
||||
|
||||
class LocalUrlService(UrlServiceBase):
|
||||
def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
|
||||
def __init__(self, base_url: str = "api/v1"):
|
||||
self._base_url = base_url
|
||||
self._base_url_v2 = base_url_v2
|
||||
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
image_basename = os.path.basename(image_name)
|
||||
@@ -16,6 +15,3 @@ class LocalUrlService(UrlServiceBase):
|
||||
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
||||
|
||||
return f"{self._base_url}/images/i/{image_basename}/full"
|
||||
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
return f"{self._base_url_v2}/models/i/{model_key}/image"
|
||||
|
||||
55
invokeai/app/util/metadata.py
Normal file
55
invokeai/app/util/metadata.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.services.shared.graph import Edge
|
||||
|
||||
|
||||
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||
"""
|
||||
Parses raw session string, returning a dict of the graph.
|
||||
|
||||
Only the general graph shape is validated; none of the fields are validated.
|
||||
|
||||
Any `metadata_accumulator` nodes and edges are removed.
|
||||
|
||||
Any validation failure will return None.
|
||||
"""
|
||||
|
||||
graph = json.loads(session_raw).get("graph", None)
|
||||
|
||||
# sanity check make sure the graph is at least reasonably shaped
|
||||
if (
|
||||
not isinstance(graph, dict)
|
||||
or "nodes" not in graph
|
||||
or not isinstance(graph["nodes"], dict)
|
||||
or "edges" not in graph
|
||||
or not isinstance(graph["edges"], list)
|
||||
):
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
try:
|
||||
# delete the `metadata_accumulator` node
|
||||
del graph["nodes"]["metadata_accumulator"]
|
||||
except KeyError:
|
||||
# no accumulator node, all good
|
||||
pass
|
||||
|
||||
# delete any edges to or from it
|
||||
for i, edge in enumerate(graph["edges"]):
|
||||
try:
|
||||
# try to parse the edge
|
||||
Edge(**edge)
|
||||
except ValidationError:
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
if (
|
||||
edge["source"]["node_id"] == "metadata_accumulator"
|
||||
or edge["destination"]["node_id"] == "metadata_accumulator"
|
||||
):
|
||||
del graph["edges"][i]
|
||||
|
||||
return graph
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
@@ -11,6 +11,7 @@ from ...backend.util.util import image_to_dataURL
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
|
||||
|
||||
@@ -33,10 +34,10 @@ def stable_diffusion_step_callback(
|
||||
context_data: "InvocationContextData",
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
base_model: BaseModelType,
|
||||
invocation_queue: "InvocationQueueABC",
|
||||
events: "EventServiceBase",
|
||||
is_canceled: Callable[[], bool],
|
||||
) -> None:
|
||||
if is_canceled():
|
||||
if invocation_queue.is_canceled(context_data.session_id):
|
||||
raise CanceledException
|
||||
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
@@ -114,12 +115,12 @@ def stable_diffusion_step_callback(
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
events.emit_generator_progress(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
queue_id=context_data.queue_id,
|
||||
queue_item_id=context_data.queue_item_id,
|
||||
queue_batch_id=context_data.batch_id,
|
||||
graph_execution_state_id=context_data.session_id,
|
||||
node_id=context_data.invocation.id,
|
||||
source_node_id=context_data.source_invocation_id,
|
||||
source_node_id=context_data.source_node_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
order=intermediate_state.order,
|
||||
|
||||
@@ -1,47 +1,8 @@
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
|
||||
def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
|
||||
ti_triggers: List[str] = []
|
||||
def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
|
||||
ti_triggers = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
ti_triggers.append(str(trigger))
|
||||
ti_triggers.append(trigger)
|
||||
return ti_triggers
|
||||
|
||||
|
||||
def generate_ti_list(
|
||||
prompt: str, base: BaseModelType, context: InvocationContext
|
||||
) -> List[Tuple[str, TextualInversionModelRaw]]:
|
||||
ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name_or_key = trigger[1:-1]
|
||||
try:
|
||||
loaded_model = context.models.load(name_or_key)
|
||||
model = loaded_model.model
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
assert loaded_model.config.base == base
|
||||
ti_list.append((name_or_key, model))
|
||||
except UnknownModelException:
|
||||
try:
|
||||
loaded_model = context.models.load_by_attrs(
|
||||
name=name_or_key, base=base, type=ModelType.TextualInversion
|
||||
)
|
||||
model = loaded_model.model
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
assert loaded_model.config.base == base
|
||||
ti_list.append((name_or_key, model))
|
||||
except UnknownModelException:
|
||||
pass
|
||||
except ValueError:
|
||||
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
||||
except AssertionError:
|
||||
logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
|
||||
except Exception:
|
||||
logger.warning(f'Failed to load TI model for trigger: "{trigger}"')
|
||||
return ti_list
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.image_util methods.
|
||||
"""
|
||||
|
||||
from .patchmatch import PatchMatch # noqa: F401
|
||||
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
|
||||
from .seamless import configure_model_padding # noqa: F401
|
||||
|
||||
@@ -3,7 +3,6 @@ This module defines a singleton object, "invisible_watermark" that
|
||||
wraps the invisible watermark model. It respects the global "invisible_watermark"
|
||||
configuration variable, that allows the watermarking to be supressed.
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from imwatermark import WatermarkEncoder
|
||||
|
||||
@@ -4,7 +4,6 @@ wraps the actual patchmatch object. It respects the global
|
||||
"try_patchmatch" attribute, so that patchmatch loading can
|
||||
be suppressed or deferred
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user