Compare commits

..

51 Commits

Author SHA1 Message Date
Ryan Dick
ff950bc5cd Add support for mask weights, and only mask the tokens associated with the prompts (not the entire 77-token embedding). 2024-03-07 14:30:51 -05:00
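A minimal sketch of the token-range idea in this commit (names here are hypothetical; the real change is visible in the latents diff further down): instead of masking all 77 padded embedding positions, only the indices a prompt's tokens actually occupy are recorded.

```python
from dataclasses import dataclass

@dataclass
class Range:
    start: int  # first embedding index covered by this prompt (inclusive)
    end: int    # last embedding index covered by this prompt (exclusive)

def prompt_token_ranges(token_counts: list[int], pad_len: int = 77) -> list[Range]:
    """Hypothetical helper: given the number of meaningful tokens per prompt,
    return the range each prompt occupies in the concatenated embedding,
    skipping the BOS token and leaving the padding unmasked."""
    ranges, offset = [], 0
    for count in token_counts:
        ranges.append(Range(start=offset + 1, end=offset + count))
        offset += pad_len  # each prompt's embedding is padded to 77 tokens
    return ranges

print(prompt_token_ranges([5, 12]))  # [Range(start=1, end=5), Range(start=78, end=89)]
```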
Ryan Dick
969982b789 Fixup some details of densediffusion for testing. 2024-03-06 19:03:26 -05:00
Ryan Dick
b8cbff828b wip 2024-03-06 10:52:35 -05:00
Ryan Dick
d3a40c5b2b Rough hacky implementation of DenseDiffusion. 2024-03-05 18:10:01 -05:00
Ryan Dick
57266d36a2 Remove dispatch_progress() function that was added accidentally during conflict resolution. 2024-03-05 15:31:54 -05:00
Ryan Dick
41e1a9f202 Use the correct device / dtype for RegionalPromptData calculations. 2024-03-05 15:19:58 -05:00
Ryan Dick
bcfb43e5f0 (minor) Remove commented code. 2024-03-05 09:12:17 -05:00
Ryan Dick
a665f20fb5 Add positive_self_attn_mask_score and self_attn_adjustment_end_step_percent to the prompt nodes. 2024-03-04 15:34:26 -05:00
Ryan Dick
d313e5eb70 Remove AddConditioningMaskInvocation. 2024-03-04 14:11:38 -05:00
Ryan Dick
271f8f2414 Merge branch 'main' into ryan/regional-conditioning-tuning 2024-03-04 10:52:24 -05:00
Mary Hipp Rogers
8b34f5298c Default model settings (#5850)
* UI in MM to create trigger phrases

* add scheduler and vaePrecision to config

* UI for configuring default settings for models

* hook MM default model settings up to API

* add button to set default settings in parameters

* pull out trigger phrases

* back-end for default settings

* lint

* remove log

* ruff

* ruff format

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2024-03-04 09:39:03 -05:00
Brandon Rising
893bcd16fc Next: Allow in-place local installs of models 2024-03-04 23:11:41 +11:00
Ryan Dick
f6028a4c61 Log a stack trace for invocation errors. 2024-03-04 23:01:56 +11:00
Hosted Weblate
264aee3ffa translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-03-04 21:39:46 +11:00
Riccardo Giovanetti
4deb60f365 translationBot(ui): update translation (Italian)
Currently translated at 98.0% (1442 of 1470 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-03-04 21:39:46 +11:00
B N
f2d5fb176f translationBot(ui): update translation (German)
Currently translated at 80.4% (1183 of 1470 strings)

Co-authored-by: B N <berndnieschalk@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-03-04 21:39:46 +11:00
Mary Hipp
94005b5501 add button to navigate to model manager if tab is enabled 2024-03-03 19:50:50 -05:00
Mary Hipp
02dc1a8780 consolidate tabs for main model and concepts in generation panel 2024-03-03 19:50:50 -05:00
Wubbbi
ef958568ac Update Transformers 4.37.2 -> 4.38.2 2024-03-03 19:41:33 -05:00
dunkeroni
48e323d887 docs: added both create mask nodes to defaultNodes 2024-03-03 12:58:47 -05:00
dunkeroni
735857479d fix(canvas): use corrected mask for pasteback 2024-03-03 12:58:47 -05:00
Ryan Dick
5fad379192 Add ability to control regional prompt region weights. 2024-03-03 12:55:07 -05:00
psychedelicious
2f372d9b18 tests(mm): update tests to reflect using UUID for key 2024-03-03 14:32:14 +11:00
psychedelicious
554d175792 feat(mm): improved model hash class
- Use memory view for hashlib algorithms (closer to python 3.11's filehash API in hashlib)
- Remove `sha1_fast` (realized it doesn't even hash the whole file, it just does the first block)
- Add support for custom file filters
- Update docstrings
- Update tests
2024-03-03 14:32:14 +11:00
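The memoryview pattern mentioned above is essentially Python 3.11's `hashlib.file_digest` reimplemented for older interpreters. A rough, self-contained sketch (not the project's exact code):

```python
import hashlib
from pathlib import Path

def hash_file(path: Path, algorithm: str = "sha256", chunk_size: int = 2**20) -> str:
    """Hash a file in fixed-size chunks, reusing one buffer through a memoryview
    (the same approach as hashlib.file_digest in Python 3.11+)."""
    hasher = hashlib.new(algorithm)
    buffer = bytearray(chunk_size)
    view = memoryview(buffer)
    with open(path, "rb") as f:
        while True:
            read = f.readinto(buffer)
            if read == 0:
                break
            hasher.update(view[:read])
    return hasher.hexdigest()

print(hash_file(Path("model.safetensors")))  # hypothetical model file path
```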
psychedelicious
ae99428883 fix(mm): use UUIDv4 for key
This changes the functionality of this PR so that the updated hashing is used only for model hashes, with a UUID used for the key.
2024-03-03 14:32:14 +11:00
psychedelicious
863ce00712 tests(mm): add tests for ModelHash 2024-03-03 14:32:14 +11:00
psychedelicious
86982f3059 feat(mm): make ModelHash instantiatable, taking an algorithm as arg 2024-03-03 14:32:14 +11:00
psychedelicious
ec8ed530a7 feat(mm): modularize ModelHash to facilitate testing 2024-03-03 14:32:14 +11:00
psychedelicious
982076d7d7 feat(mm): add hashing algos to ModelHash
- Some algos are slow, so the class is now just called ModelHash (rather than FastModelHash)
- Added all hashlib algos, plus BLAKE3 and the fast (but incorrect) SHA1 algo
2024-03-03 14:32:14 +11:00
psychedelicious
2e4672f931 feat(mm): make hash.py a script for testing 2024-03-03 14:32:14 +11:00
psychedelicious
908e915a71 feat(mm): use blake3 for hashing 2024-03-03 14:32:14 +11:00
Lincoln Stein
a72056e0df make model key assignment deterministic
- When installing, model keys are now calculated from the model contents.
- .safetensors, .ckpt and other single file models are hashed with sha1
- The contents of diffusers directories are hashed using imohash (faster)

fixup yaml->sql db migration script to assign deterministic key

- this commit also detects and assigns the correct image encoder for
  ip adapter models.
2024-03-03 14:32:14 +11:00
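A hedged sketch of the deterministic key scheme described in this commit (single-file models hashed with SHA-1, diffusers directories via imohash); the function name and the directory-combination step are illustrative, not the project's exact implementation:

```python
import hashlib
from pathlib import Path

from imohash import hashfile  # fast sampling hash, used here for directories

def deterministic_model_key(model_path: Path) -> str:
    """Derive a stable key from the model contents instead of random bytes."""
    if model_path.is_file():
        # .safetensors, .ckpt and other single-file models: full SHA-1
        sha = hashlib.sha1()
        with open(model_path, "rb") as f:
            for chunk in iter(lambda: f.read(2**20), b""):
                sha.update(chunk)
        return sha.hexdigest()
    # diffusers directories: combine per-file imohash digests in a stable order
    digests = sorted(hashfile(p, hexdigest=True) for p in model_path.rglob("*") if p.is_file())
    return hashlib.sha1("".join(digests).encode()).hexdigest()
```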
Ryan Dick
ad18429fe3 Experimentation with various regional prompting tuning params. 2024-03-02 17:43:21 -05:00
Ryan Dick
d8d7ddf43a Remove attention map saving (#5845)
## What type of PR is this? (check all applicable)

- [x] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because

## Description

Attention map saving was a feature that existed a long time ago in
Invoke (>1 year ago). This PR strips out a bunch of dead code that still
remains from that feature and is polluting our diffusion implementation.

This change should not have any functional effect on the app.

## QA Instructions, Screenshots, Recordings

I did a quick smoke test of SD and SDXL image generation. All of the
deleted code was unused, so the risk should be relatively low.

## Merge Plan

- [x] Change target branch to `main` before merging.

## Added/updated tests?

- [ ] Yes
- [x] No: This PR just deletes a bunch of unused code.
2024-03-02 11:15:25 -05:00
Ryan Dick
cc45007dc4 Remove unused code for attention map saving. 2024-03-02 08:25:41 -05:00
Ryan Dick
73bec56c59 Delete unused functions from shared_invokeai_diffusion.py. 2024-03-02 08:25:41 -05:00
psychedelicious
f8b54930f0 docs: update RELEASE.md 2024-03-02 08:23:49 -05:00
psychedelicious
51cc9f9466 ci: add comments to workflows 2024-03-02 08:23:49 -05:00
psychedelicious
d2ad465e96 ci: rename test matrix
Now python version: platform, e.g. `py3.10: linux-cpu`

This displays better in GH actions.
2024-03-02 08:23:49 -05:00
psychedelicious
09037b7cd4 ci: add conditionals for jobs based on dispatch/call 2024-03-02 08:23:49 -05:00
psychedelicious
b2a850b5ea ci: rename jobs, remove extraneous needs in release 2024-03-02 08:23:49 -05:00
psychedelicious
3ba5c2b0b4 ci: split build job 2024-03-02 08:23:49 -05:00
psychedelicious
06fc6ccfe5 ci: workflow & job names 2024-03-02 08:23:49 -05:00
psychedelicious
0c6b0cfdab ci: tidy pr labeler labels 2024-03-02 08:23:49 -05:00
psychedelicious
eef3373799 ci: fix workflows
Do not split up "on change" and "do the thing". Less convoluted, no catch-22 with required checks for PRs.
2024-03-02 08:23:49 -05:00
Ryan Dick
942efa011e Implement (very slow) self-attention regional masking. 2024-03-01 18:43:32 -05:00
Ryan Dick
6935830f99 Remove unused constructor declared with typo in name: __int__. 2024-03-01 15:12:03 -05:00
Ryan Dick
7651eeea8d Merge sequential conditioning and cac conditioning logic to eliminate a bunch of duplication. 2024-03-01 15:12:03 -05:00
Ryan Dick
204e7d383b Remove outdated comments related to T2I-Adapters and ControlNets. 2024-03-01 15:12:03 -05:00
Ryan Dick
9bc4e7a593 Remove use of **kwargs in do_unet_step(...), where full parameter list is known and supported. 2024-03-01 15:12:03 -05:00
Ryan Dick
ad96857e0f Fix: avoid storing extra conditioning info in two places. 2024-03-01 15:12:03 -05:00
79 changed files with 2237 additions and 893 deletions


@@ -1,33 +1,33 @@
name: Install frontend dependencies
name: install frontend dependencies
description: Installs frontend dependencies with pnpm, with caching
runs:
using: 'composite'
steps:
- name: Setup Node 18
- name: setup node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: Setup pnpm
- name: setup pnpm
uses: pnpm/action-setup@v2
with:
version: 8
run_install: false
- name: Get pnpm store directory
- name: get pnpm store directory
shell: bash
run: |
echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
- uses: actions/cache@v3
name: Setup pnpm cache
- name: setup cache
uses: actions/cache@v4
with:
path: ${{ env.STORE_PATH }}
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
restore-keys: |
${{ runner.os }}-pnpm-store-
- name: Install frontend dependencies
- name: install frontend dependencies
run: pnpm install --prefer-frozen-lockfile
shell: bash
working-directory: invokeai/frontend/web


@@ -1,11 +0,0 @@
name: Install python dependencies
description: Install python dependencies with pip, with caching
runs:
using: 'composite'
steps:
- name: Setup python
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml

.github/pr_labels.yml (28 changed lines)

@@ -1,59 +1,59 @@
Root:
root:
- changed-files:
- any-glob-to-any-file: '*'
PythonDeps:
python-deps:
- changed-files:
- any-glob-to-any-file: 'pyproject.toml'
Python:
python:
- changed-files:
- all-globs-to-any-file:
- 'invokeai/**'
- '!invokeai/frontend/web/**'
PythonTests:
python-tests:
- changed-files:
- any-glob-to-any-file: 'tests/**'
CICD:
ci-cd:
- changed-files:
- any-glob-to-any-file: .github/**
Docker:
docker:
- changed-files:
- any-glob-to-any-file: docker/**
Installer:
installer:
- changed-files:
- any-glob-to-any-file: installer/**
Documentation:
docs:
- changed-files:
- any-glob-to-any-file: docs/**
Invocations:
invocations:
- changed-files:
- any-glob-to-any-file: 'invokeai/app/invocations/**'
Backend:
backend:
- changed-files:
- any-glob-to-any-file: 'invokeai/backend/**'
Api:
api:
- changed-files:
- any-glob-to-any-file: 'invokeai/app/api/**'
Services:
services:
- changed-files:
- any-glob-to-any-file: 'invokeai/app/services/**'
FrontendDeps:
frontend-deps:
- changed-files:
- any-glob-to-any-file:
- '**/*/package.json'
- '**/*/pnpm-lock.yaml'
Frontend:
frontend:
- changed-files:
- any-glob-to-any-file: 'invokeai/frontend/web/**'

.github/workflows/build-installer.yml (new file, 45 lines)

@@ -0,0 +1,45 @@
# Builds and uploads the installer and python build artifacts.
name: build installer
on:
workflow_dispatch:
workflow_call:
jobs:
build-installer:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <2 min
steps:
- name: checkout
uses: actions/checkout@v4
- name: setup python
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install pypa/build
run: pip install --upgrade build
- name: setup frontend
uses: ./.github/actions/install-frontend-deps
- name: create installer
id: create_installer
run: ./create_installer.sh
working-directory: installer
- name: upload python distribution artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: ${{ steps.create_installer.outputs.DIST_PATH }}
- name: upload installer artifact
uses: actions/upload-artifact@v4
with:
name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }}
path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}


@@ -1,43 +0,0 @@
# This workflow runs the frontend code quality checks.
#
# It may be triggered via dispatch, or by another workflow.
name: 'Check: frontend'
on:
workflow_dispatch:
workflow_call:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
check-frontend:
runs-on: ubuntu-latest
timeout-minutes: 10 # expected run time: <2 min
steps:
- uses: actions/checkout@v4
- name: Set up frontend
uses: ./.github/actions/install-frontend-deps
- name: Run tsc check
run: 'pnpm run lint:tsc'
shell: bash
- name: Run dpdm check
run: 'pnpm run lint:dpdm'
shell: bash
- name: Run eslint check
run: 'pnpm run lint:eslint'
shell: bash
- name: Run prettier check
run: 'pnpm run lint:prettier'
shell: bash
- name: Run knip check
run: 'pnpm run lint:knip'
shell: bash


@@ -1,72 +0,0 @@
# This workflow runs pytest on the codebase in a matrix of platforms.
#
# It may be triggered via dispatch, or by another workflow.
name: 'Check: pytest'
on:
workflow_dispatch:
workflow_call:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
matrix:
strategy:
matrix:
python-version:
- '3.10'
pytorch:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- pytorch: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- pytorch: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- pytorch: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- pytorch: macos-default
os: macOS-12
github-env: $GITHUB_ENV
- pytorch: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
timeout-minutes: 30 # expected run time: <10 min, depending on platform
env:
PIP_USE_PEP517: '1'
steps:
- uses: actions/checkout@v4
- name: set test prompt to main branch validation
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
- name: setup python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install invokeai
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install
--editable=".[test]"
- name: run pytest
id: run-pytest
run: pytest


@@ -1,33 +0,0 @@
# This workflow runs the python code quality checks.
#
# It may be triggered via dispatch, or by another workflow.
#
# TODO: Add mypy or pyright to the checks.
name: 'Check: python'
on:
workflow_dispatch:
workflow_call:
jobs:
check-backend:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
steps:
- uses: actions/checkout@v4
- name: Install python dependencies
uses: ./.github/actions/install-python-deps
- name: Install ruff
run: pip install ruff
shell: bash
- name: Ruff check
run: ruff check --output-format=github .
shell: bash
- name: Ruff format
run: ruff format --check .
shell: bash

.github/workflows/frontend-checks.yml (new file, 68 lines)

@@ -0,0 +1,68 @@
# Runs frontend code quality checks.
#
# Checks for changes to frontend files before running the checks.
# When manually triggered or when called from another workflow, always runs the checks.
name: 'frontend checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
frontend-checks:
runs-on: ubuntu-latest
timeout-minutes: 10 # expected run time: <2 min
steps:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
frontend:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: ./.github/actions/install-frontend-deps
- name: tsc
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:tsc'
shell: bash
- name: dpdm
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:dpdm'
shell: bash
- name: eslint
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:eslint'
shell: bash
- name: prettier
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:prettier'
shell: bash
- name: knip
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:knip'
shell: bash

.github/workflows/frontend-tests.yml (new file, 48 lines)

@@ -0,0 +1,48 @@
# Runs frontend tests.
#
# Checks for changes to frontend files before running the tests.
# When manually triggered or called from another workflow, always runs the tests.
name: 'frontend tests'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
frontend-tests:
runs-on: ubuntu-latest
timeout-minutes: 10 # expected run time: <2 min
steps:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
frontend:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: ./.github/actions/install-frontend-deps
- name: vitest
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm test:no-watch'
shell: bash


@@ -1,6 +1,6 @@
name: "Pull Request Labeler"
name: 'label PRs'
on:
- pull_request_target
- pull_request_target
jobs:
labeler:
@@ -9,8 +9,10 @@ jobs:
pull-requests: write
runs-on: ubuntu-latest
steps:
- name: Checkout
- name: checkout
uses: actions/checkout@v4
- uses: actions/labeler@v5
- name: label PRs
uses: actions/labeler@v5
with:
configuration-path: .github/pr_labels.yml
configuration-path: .github/pr_labels.yml


@@ -21,18 +21,29 @@ jobs:
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- name: checkout
uses: actions/checkout@v4
- name: setup python
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- uses: actions/cache@v4
- name: set cache id
run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- name: use cache
uses: actions/cache@v4
with:
key: mkdocs-material-${{ env.cache_id }}
path: .cache
restore-keys: |
mkdocs-material-
- run: python -m pip install ".[docs]"
- run: mkdocs gh-deploy --force
- name: install dependencies
run: python -m pip install ".[docs]"
- name: build & deploy
run: mkdocs gh-deploy --force


@@ -1,39 +0,0 @@
# This workflow runs of `check-frontend.yml` on push or pull request.
#
# The actual checks are in a separate workflow to support simpler workflow
# composition without awkward or complicated conditionals.
name: 'On change: run check-frontend'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
jobs:
check-changed-frontend-files:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
outputs:
frontend_any_changed: ${{ steps.changed-files.outputs.frontend_any_changed }}
steps:
- uses: actions/checkout@v4
- name: Check for changed frontend files
id: changed-files
uses: tj-actions/changed-files@v41
with:
files_yaml: |
frontend:
- 'invokeai/frontend/web/**'
run-check-frontend:
needs: check-changed-frontend-files
if: ${{ needs.check-changed-frontend-files.outputs.frontend_any_changed == 'true' }}
uses: ./.github/workflows/check-frontend.yml


@@ -1,42 +0,0 @@
# This workflow runs of `check-python.yml` on push or pull request.
#
# The actual checks are in a separate workflow to support simpler workflow
# composition without awkward or complicated conditionals.
name: 'On change: run check-python'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
jobs:
check-changed-python-files:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
outputs:
python_any_changed: ${{ steps.changed-files.outputs.python_any_changed }}
steps:
- uses: actions/checkout@v4
- name: Check for changed python files
id: changed-files
uses: tj-actions/changed-files@v41
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
run-check-python:
needs: check-changed-python-files
if: ${{ needs.check-changed-python-files.outputs.python_any_changed == 'true' }}
uses: ./.github/workflows/check-python.yml


@@ -1,42 +0,0 @@
# This workflow runs of `check-pytest.yml` on push or pull request.
#
# The actual checks are in a separate workflow to support simpler workflow
# composition without awkward or complicated conditionals.
name: 'On change: run pytest'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
jobs:
check-changed-python-files:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
outputs:
python_any_changed: ${{ steps.changed-files.outputs.python_any_changed }}
steps:
- uses: actions/checkout@v4
- name: Check for changed python files
id: changed-files
uses: tj-actions/changed-files@v41
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
run-pytest:
needs: check-changed-python-files
if: ${{ needs.check-changed-python-files.outputs.python_any_changed == 'true' }}
uses: ./.github/workflows/check-pytest.yml

.github/workflows/python-checks.yml (new file, 64 lines)

@@ -0,0 +1,64 @@
# Runs python code quality checks.
#
# Checks for changes to python files before running the checks.
# When manually triggered or called from another workflow, always runs the tests.
#
# TODO: Add mypy or pyright to the checks.
name: 'python checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
jobs:
python-checks:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
steps:
- name: checkout
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: pip install ruff
shell: bash
- name: ruff check
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: ruff check --output-format=github .
shell: bash
- name: ruff format
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: ruff format --check .
shell: bash

.github/workflows/python-tests.yml (new file, 94 lines)

@@ -0,0 +1,94 @@
# Runs python tests on a matrix of python versions and platforms.
#
# Checks for changes to python files before running the tests.
# When manually triggered or called from another workflow, always runs the tests.
name: 'python tests'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
matrix:
strategy:
matrix:
python-version:
- '3.10'
- '3.11'
platform:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- platform: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- platform: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- platform: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
os: macOS-12
github-env: $GITHUB_ENV
- platform: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
name: 'py${{ matrix.python-version }}: ${{ matrix.platform }}'
runs-on: ${{ matrix.os }}
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
env:
PIP_USE_PEP517: '1'
steps:
- name: checkout
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install dependencies
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install --editable=".[test]"
- name: run pytest
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: pytest


@@ -1,103 +1,96 @@
name: Release
# Main release workflow. Triggered on tag push or manual trigger.
#
# - Runs all code checks and tests
# - Verifies the app version matches the tag version.
# - Builds the installer and build, uploading them as artifacts.
# - Publishes to TestPyPI and PyPI. Both are conditional on the previous steps passing and require a manual approval.
#
# See docs/RELEASE.md for more information on the release process.
name: release
on:
push:
tags:
- 'v*'
workflow_dispatch:
inputs:
skip_code_checks:
description: 'Skip code checks'
required: true
default: true
type: boolean
jobs:
check-version:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: checkout
uses: actions/checkout@v4
- uses: samuelcolvin/check-python-version@v4
- name: check python version
uses: samuelcolvin/check-python-version@v4
id: check-python-version
with:
version_file_path: invokeai/version/invokeai_version.py
check-frontend:
if: github.event.inputs.skip_code_checks != 'true'
uses: ./.github/workflows/check-frontend.yml
frontend-checks:
uses: ./.github/workflows/frontend-checks.yml
check-python:
if: github.event.inputs.skip_code_checks != 'true'
uses: ./.github/workflows/check-python.yml
frontend-tests:
uses: ./.github/workflows/frontend-tests.yml
check-pytest:
if: github.event.inputs.skip_code_checks != 'true'
uses: ./.github/workflows/check-pytest.yml
python-checks:
uses: ./.github/workflows/python-checks.yml
python-tests:
uses: ./.github/workflows/python-tests.yml
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install python dependencies
uses: ./.github/actions/install-python-deps
- name: Install pypa/build
run: pip install --upgrade build
- name: Setup frontend
uses: ./.github/actions/install-frontend-deps
- name: Run create_installer.sh
id: create_installer
run: ./create_installer.sh --skip_frontend_checks
working-directory: installer
- name: Upload python distribution artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: ${{ steps.create_installer.outputs.DIST_PATH }}
- name: Upload installer artifact
uses: actions/upload-artifact@v4
with:
name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }}
path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}
uses: ./.github/workflows/build-installer.yml
publish-testpypi:
runs-on: ubuntu-latest
needs: [check-version, check-frontend, check-python, check-pytest, build]
if: github.event_name != 'workflow_dispatch'
timeout-minutes: 5 # expected run time: <1 min
needs:
[
check-version,
frontend-checks,
frontend-tests,
python-checks,
python-tests,
build,
]
environment:
name: testpypi
url: https://test.pypi.org/p/invokeai
steps:
- name: Download distribution from build job
- name: download distribution from build job
uses: actions/download-artifact@v4
with:
name: dist
path: dist/
- name: Publish distribution to TestPyPI
- name: publish distribution to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
publish-pypi:
runs-on: ubuntu-latest
needs: [check-version, check-frontend, check-python, check-pytest, build]
if: github.event_name != 'workflow_dispatch'
timeout-minutes: 5 # expected run time: <1 min
needs:
[
check-version,
frontend-checks,
frontend-tests,
python-checks,
python-tests,
build,
]
environment:
name: pypi
url: https://pypi.org/p/invokeai
steps:
- name: Download distribution from build job
- name: download distribution from build job
uses: actions/download-artifact@v4
with:
name: dist
path: dist/
- name: Publish distribution to PyPI
- name: publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1


@@ -23,13 +23,13 @@ It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if
Run `make tag-release` to tag the current commit and kick off the workflow.
The release may also be run [manually].
The release may also be dispatched [manually].
### Workflow Jobs and Process
The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
The publish jobs run if the 5 concurrent jobs all succeed and if/when the publish jobs are approved.
The publish jobs require manual approval and are only run if the other jobs succeed.
#### `check-version` Job
@@ -43,17 +43,16 @@ This job uses [samuelcolvin/check-python-version].
#### Check and Test Jobs
This is our test suite.
- **`check-pytest`**: runs `pytest` on matrix of platforms
- **`check-python`**: runs `ruff` (format and lint)
- **`check-frontend`**: runs `prettier` (format), `eslint` (lint), `madge` (circular refs) and `tsc` (static type check)
- **`python-tests`**: runs `pytest` on matrix of platforms
- **`python-checks`**: runs `ruff` (format and lint)
- **`frontend-tests`**: runs `vitest`
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
> **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
> **TODO** We should add an end-to-end test job that generates an image.
#### `build` Job
#### `build-installer` Job
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
@@ -62,7 +61,7 @@ This sets up both python and frontend dependencies and builds the python package
#### Sanity Check & Smoke Test
At this point, the release workflow pauses (the remaining jobs all require approval).
At this point, the release workflow pauses as the remaining publish jobs require approval.
A maintainer should go to the **Summary** tab of the workflow, download the installer and test it. Ensure the app loads and generates.
@@ -70,7 +69,7 @@ A maintainer should go to the **Summary** tab of the workflow, download the inst
#### PyPI Publish Jobs
The publish jobs will skip if any of the previous jobs skip or fail.
The publish jobs will run if any of the previous jobs fail.
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
@@ -119,13 +118,17 @@ Once the release is published to PyPI, it's time to publish the GitHub release.
> **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
## Manually Running the Release Workflow
## Manual Build
The release workflow can be run manually. This is useful to get an installer build and test it out without needing to push a tag.
The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
When run this way, you'll see **Skip code checks** checkbox. This allows the workflow to run without the time-consuming 3 code quality check jobs.
No checks are run, it just builds.
The publish jobs will skip if the workflow was run manually.
## Manual Release
The `release` workflow can be dispatched manually. You must dispatch the workflow from the right tag, else it will fail the version check.
This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
[InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
[PyPI]: https://pypi.org/
@@ -136,4 +139,4 @@ The publish jobs will skip if the workflow was run manually.
[GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
[trusted publishers]: https://docs.pypi.org/trusted-publishers/
[samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version
[manually]: #manually-running-the-release-workflow
[manually]: #manual-release


@@ -19,6 +19,8 @@ their descriptions.
| Conditioning Primitive | A conditioning tensor primitive value |
| Content Shuffle Processor | Applies content shuffle processing to image |
| ControlNet | Collects ControlNet info to pass to other nodes |
| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. |
| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
| Denoise Latents | Denoises noisy latents to decodable images |
| Divide Integers | Divides two numbers |
| Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |


@@ -14,6 +14,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
@@ -32,6 +33,7 @@ from invokeai.backend.model_manager.config import (
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
@@ -243,6 +245,47 @@ async def get_model_metadata(
return result
@model_manager_router.patch(
"/i/{key}/metadata",
operation_id="update_model_metadata",
responses={
201: {
"description": "The model metadata was updated successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def update_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
changes: ModelMetadataChanges = Body(description="The changes"),
) -> Optional[AnyModelRepoMetadata]:
"""Updates or creates a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
try:
original_metadata = record_store.get_metadata(key)
if original_metadata:
if changes.default_settings:
original_metadata.default_settings = changes.default_settings
metadata_store.update_metadata(key, original_metadata)
else:
metadata_store.add_metadata(
key, BaseMetadata(name="", author="", default_settings=changes.default_settings)
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"An error occurred while updating the model metadata: {e}",
)
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.get(
"/tags",
operation_id="list_tags",
@@ -451,6 +494,7 @@ async def add_model_record(
)
async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
# TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
@@ -493,6 +537,7 @@ async def install_model(
source=source,
config=config,
access_token=access_token,
inplace=bool(inplace),
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
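For orientation, a hedged example of calling the new `update_model_metadata` route added in this hunk. The host, port, and router prefix are assumptions (they are not part of this diff), and the `default_settings` fields shown are illustrative only:

```python
import requests

BASE = "http://127.0.0.1:9090/api/v2/models"  # assumed host/port and router prefix
key = "some-model-key"  # placeholder model key

# ModelMetadataChanges only accepts `default_settings` changes
payload = {"default_settings": {"scheduler": "euler", "steps": 30}}  # illustrative fields

resp = requests.patch(f"{BASE}/i/{key}/metadata", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json())  # the updated metadata object, or null if none exists
```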


@@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
MaskField,
OutputField,
UIComponent,
)
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
@@ -36,7 +44,7 @@ from .model import ClipField
title="Prompt",
tags=["prompt", "compel"],
category="conditioning",
version="1.0.1",
version="1.2.0",
)
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
@@ -51,6 +59,10 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
mask_weight: float = InputField(default=1.0, description="")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@@ -118,7 +130,13 @@ class CompelInvocation(BaseInvocation):
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
mask_weight=self.mask_weight,
)
)
class SDXLPromptInvocationBase:
@@ -232,7 +250,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
version="1.0.1",
version="1.2.0",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@@ -256,6 +274,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
mask_weight: float = InputField(default=1.0, description="")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(
@@ -317,7 +340,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
mask_weight=self.mask_weight,
)
)
@invocation(
@@ -366,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name)
return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name, mask_weight=1.0))
@invocation_output("clip_skip_output")


@@ -6,25 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
)
from invokeai.app.invocations.fields import InputField, WithMetadata
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, MaskField, MaskOutput
@invocation(
"add_conditioning_mask",
title="Add Conditioning Mask",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
)
class AddConditioningMaskInvocation(BaseInvocation):
"""Add a mask to an existing conditioning tensor."""
conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
mask: MaskField = InputField(description="A mask to add to the conditioning tensor.")
def invoke(self, context: InvocationContext) -> ConditioningOutput:
self.conditioning.mask = self.mask
return ConditioningOutput(conditioning=self.conditioning)
from invokeai.app.invocations.primitives import MaskField, MaskOutput
@invocation(


@@ -236,6 +236,7 @@ class ConditioningField(BaseModel):
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.",
)
mask_weight: float = Field(description="")
class MetadataField(RootModel):


@@ -51,7 +51,6 @@ from invokeai.app.invocations.primitives import (
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel
@@ -181,6 +180,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
@@ -201,38 +210,42 @@ class CreateGradientMaskInvocation(BaseInvocation):
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
if self.edge_radius > 0:
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the edges are 0 and blur out to 1
blur_tensor = (blur_tensor - 0.5) * 2
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
if self.coherence_mode == "Staged":
# wherever the blur_tensor is masked to any degree, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
# multiply original mask to force actually masked regions to 0
blur_tensor = mask_tensor * blur_tensor
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
return DenoiseMaskOutput.build(
mask_name=mask_name,
masked_latents_name=None,
gradient=True,
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)
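To make the blur redistribution in the hunk above easier to follow, here is a small standalone sketch of the same arithmetic with made-up values (1D instead of a real mask image; the final multiply by the original mask is omitted):

```python
import torch

minimum_denoise = 0.2             # node parameter (illustrative value)
threshold = 1 - minimum_denoise   # 0.8

# stand-in for the blurred mask: 0 = fully masked, 1 = fully unmasked
blur = torch.tensor([0.0, 0.6, 0.92, 0.97, 1.0])

# redistribute so the original mask edge maps to 0 and the blur fades out to 1
blur = (blur - 0.5) * 2           # -> [-1.0, 0.2, 0.84, 0.94, 1.0]

# non-Staged modes: partially blurred values above the threshold drop to it
blur = torch.where((blur > threshold) & (blur < 1), threshold, blur)
# -> [-1.0, 0.2, 0.8, 0.8, 1.0]

# binary mask of everything the gradient touched, used for paste-back
expanded = torch.where(blur < 1, 0, 1)   # -> [0, 0, 0, 0, 1]
```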
@@ -359,35 +372,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
base_model: BaseModelType,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.model_dump(),
source_node_id=source_node_id,
base_model=base_model,
)
def _get_text_embeddings_and_masks(
self,
cond_field: Union[ConditioningField, list[ConditioningField]],
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,
dtype: torch.dtype,
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
"""Get the text embeddings and masks from the input conditioning fields."""
# Normalize cond_field to a list.
cond_list = cond_field
if not isinstance(cond_list, list):
cond_list = [cond_list]
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
for cond in cond_list:
@@ -427,6 +419,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
conditioning_fields: list[ConditioningField],
latent_height: int,
latent_width: int,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
@@ -447,7 +440,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
embedding_ranges = []
extra_conditioning = None
for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
mask = masks[prompt_idx]
if (
text_embedding_info.extra_conditioning is not None
and text_embedding_info.extra_conditioning.wants_cross_attention_control
@@ -472,9 +466,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embedding.append(text_embedding_info.embeds)
if not all_masks_are_none:
# embedding_ranges.append(
# Range(
# start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
# )
# )
# HACK(ryand): Contrary to its name, tokens_count_including_eos_bos does not seem to include eos and bos
# in the count.
embedding_ranges.append(
Range(
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
start=cur_text_embedding_len + 1,
end=cur_text_embedding_len
+ text_embedding_info.extra_conditioning.tokens_count_including_eos_bos,
)
)
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
@@ -486,7 +489,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
regions = None
if not all_masks_are_none:
regions = TextConditioningRegions(masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges)
regions = TextConditioningRegions(
masks=torch.cat(processed_masks, dim=1),
ranges=embedding_ranges,
mask_weights=[x.mask_weight for x in conditioning_fields],
)
if extra_conditioning is not None and len(text_conditionings) > 1:
raise ValueError(
@@ -509,25 +516,36 @@ class DenoiseLatentsInvocation(BaseInvocation):
def get_conditioning_data(
self,
context: InvocationContext,
unet,
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
) -> TextConditioningData:
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
cond_list = self.positive_conditioning
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = self.negative_conditioning
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
self.positive_conditioning, context, unet.device, unet.dtype
cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
self.negative_conditioning, context, unet.device, unet.dtype
uncond_list, context, unet.device, unet.dtype
)
cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
conditioning_fields=cond_list,
latent_height=latent_height,
latent_width=latent_width,
)
uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
conditioning_fields=uncond_list,
latent_height=latent_height,
latent_width=latent_width,
)


@@ -427,10 +427,6 @@ class ConditioningOutput(BaseInvocationOutput):
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "ConditioningOutput":
return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_collection_output")
class ConditioningCollectionOutput(BaseInvocationOutput):


@@ -7,7 +7,6 @@ import time
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union
@@ -21,6 +20,7 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@@ -150,7 +150,7 @@ class ModelInstallService(ModelInstallServiceBase):
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["key"] = config.get("key", self._create_key())
config["key"] = config.get("key", uuid_string())
info: AnyModelConfig = self._probe_model(Path(model_path), config)
@@ -178,13 +178,14 @@ class ModelInstallService(ModelInstallServiceBase):
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: bool = False,
) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
@@ -526,16 +527,17 @@ class ModelInstallService(ModelInstallServiceBase):
setattr(info, key, value)
return info
def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32]
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
# Note that we may be passed a pre-populated AnyModelConfig object,
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
config["key"] = config.get("key", self._create_key())
config["key"] = config.get("key", uuid_string())
info = info or ModelProbe.probe(model_path, config)
override_key: Optional[str] = config.get("key") if config else None
assert info.original_hash # always assigned by probe()
info.key = override_key or info.original_hash
model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path):


@@ -4,9 +4,25 @@ Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
from typing import List, Optional, Set, Tuple
from pydantic import Field
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata.
Only limited changes are valid:
- `default_settings`: the user-configured default settings for this model
"""
default_settings: Optional[ModelDefaultSettings] = Field(
default=None, description="The user-configured default settings for this model"
)
"""The user-configured default settings for this model"""
class ModelMetadataStoreBase(ABC):


@@ -179,44 +179,45 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
if tags:
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)


@@ -200,6 +200,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(

View File

@@ -3,7 +3,6 @@
import json
import sqlite3
from hashlib import sha1
from logging import Logger
from pathlib import Path
from typing import Optional
@@ -22,7 +21,7 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.hash import ModelHash
ModelsValidator = TypeAdapter(AnyModelConfig)
@@ -73,19 +72,27 @@ class MigrateModelYamlToDb1:
base_type, model_type, model_name = str(model_key).split("/")
try:
hash = FastModelHash.hash(self.config.models_path / stanza.path)
hash = ModelHash().hash(self.config.models_path / stanza.path)
except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_key = hash # deterministic key assignment
# special case for ip adapters, which need the new `image_encoder_model_id` field
if stanza["type"] == ModelType.IPAdapter:
try:
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
self.config.models_path / stanza.path
)
except OSError:
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
continue
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
@@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
self._update_model(key, new_config)
else:
self.logger.info(f"Adding model {model_name} with key {model_key}")
self.logger.info(f"Adding model {model_name} with key {new_key}")
self._add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
@@ -149,3 +156,8 @@ class MigrateModelYamlToDb1:
)
except sqlite3.IntegrityError as exc:
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
def _get_image_encoder_model_id(self, model_path: Path) -> str:
with open(model_path / "image_encoder.txt") as f:
encoder = f.read()
return encoder.strip()

View File
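
The key assignment in this migration changes from a SHA-1 of the `base/type/name` yaml key to the model's content hash, so re-running the migration against the same weights always yields the same key. A small sketch contrasting the two, with placeholder values:

```py
from hashlib import sha1

model_key = "sd-1/main/stable-diffusion-v1-5"   # yaml-style key (placeholder)
content_hash = "0123abcd..."                     # value returned by ModelHash().hash() (placeholder)

# Old behaviour: key depends only on the yaml key string.
old_key = sha1(model_key.encode("utf-8")).hexdigest()

# New behaviour: key is the content hash itself -- deterministic for the same weights.
new_key = content_hash

print(old_key[:12], new_key)
```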

@@ -11,56 +11,175 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
import hashlib
import os
from pathlib import Path
from typing import Dict, Union
from typing import Callable, Literal, Optional, Union
from imohash import hashfile
from blake3 import blake3
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHM = Literal[
"md5",
"sha1",
"sha224",
"sha256",
"sha384",
"sha512",
"blake2b",
"blake2s",
"sha3_224",
"sha3_256",
"sha3_384",
"sha3_512",
"shake_128",
"shake_256",
"blake3",
]
class FastModelHash(object):
"""FastModelHash obect provides one public class method, hash()."""
class ModelHash:
"""
Creates a hash of a model using a specified algorithm.
@classmethod
def hash(cls, model_location: Union[str, Path]) -> str:
"""
Return hexdigest string for model located at model_location.
Args:
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
file_filter: A function that takes a file name and returns True if the file should be included in the hash.
:param model_location: Path to the model
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location)
If the model is a single file, it is hashed directly using the provided algorithm.
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
Usage:
```py
# BLAKE3 hash
ModelHash().hash("path/to/some/model.safetensors")
# MD5
ModelHash("md5").hash("path/to/model/dir/")
```
"""
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
else:
raise OSError(f"Not a valid file or directory: {model_location}")
raise ValueError(f"Algorithm {algorithm} not available")
@classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str:
self._file_filter = file_filter or self._default_file_filter
def hash(self, model_path: Union[str, Path]) -> str:
"""
Fasthash a single file and return its hexdigest.
Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
:param model_location: Path to the model file
If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
directory. The final composite hash is always computed using BLAKE3.
Args:
model_path: Path to the model
Returns:
str: Hexdigest of the hash of the model
"""
# we return md5 hash of the filehash to make it shorter
# cryptographic security not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
@classmethod
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
components: Dict[str, str] = {}
model_path = Path(model_path)
if model_path.is_file():
return self._hash_file(model_path)
elif model_path.is_dir():
return self._hash_dir(model_path)
else:
raise OSError(f"Not a valid file or directory: {model_path}")
for root, _dirs, files in os.walk(model_location):
for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
components.update({path: fast_hash})
def _hash_dir(self, dir: Path) -> str:
"""Compute the hash for all files in a directory and return a hexdigest.
# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()
for _path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8"))
return md5.hexdigest()
Args:
dir: Path to the directory
Returns:
str: Hexdigest of the hash of the directory
"""
model_component_paths = self._get_file_paths(dir, self._file_filter)
component_hashes: list[str] = []
for component in sorted(model_component_paths):
component_hashes.append(self._hash_file(component))
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
# for the composite hash
composite_hasher = blake3()
for h in component_hashes:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
@staticmethod
def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
"""Return a list of all model files in the directory.
Args:
model_path: Path to the model
file_filter: Function that takes a file name and returns True if the file should be included in the list.
Returns:
List of all model files in the directory
"""
files: list[Path] = []
for root, _dirs, _files in os.walk(model_path):
for file in _files:
if file_filter(file):
files.append(Path(root, file))
return files
@staticmethod
def _blake3(file_path: Path) -> str:
"""Hashes a file using BLAKE3
Args:
file_path: Path to the file to hash
Returns:
Hexdigest of the hash of the file
"""
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(file_path)
return file_hasher.hexdigest()
@staticmethod
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Factory function that returns a function to hash a file with the given algorithm.
Args:
algorithm: Hashing algorithm to use
Returns:
A function that hashes a file using the given algorithm
"""
def hashlib_hasher(file_path: Path) -> str:
"""Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
hasher = hashlib.new(algorithm)
buffer = bytearray(128 * 1024)
mv = memoryview(buffer)
with open(file_path, "rb", buffering=0) as f:
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest()
return hashlib_hasher
@staticmethod
def _default_file_filter(file_path: str) -> bool:
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
Args:
file_path: Path to the file
Returns:
True if the file matches the given extensions, otherwise False
"""
return file_path.endswith(MODEL_FILE_EXTENSIONS)

View File
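
The docstring above already shows the basic calls; here is a slightly fuller usage sketch, assuming InvokeAI (and the `blake3` package) is installed and that model files exist at the placeholder paths. The custom `file_filter` restricts a directory hash to `.safetensors` files only:

```py
from pathlib import Path

# Import path matches the one used elsewhere in this diff.
from invokeai.backend.model_manager.hash import ModelHash

model_file = Path("models/sd15/model.safetensors")   # placeholder path
model_dir = Path("models/sdxl-base/")                # placeholder path

# Default: BLAKE3 for files; a directory hash is a BLAKE3 digest of the
# per-file hashes, visited in sorted order.
print(ModelHash().hash(model_file))

# Any hashlib algorithm can be selected instead; files are streamed through
# a reusable 128 KiB buffer rather than read whole into memory.
print(ModelHash("sha256").hash(model_file))

# Only hash .safetensors files when hashing a directory.
safetensors_only = ModelHash(file_filter=lambda name: name.endswith(".safetensors"))
print(safetensors_only.hash(model_dir))
```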

@@ -25,6 +25,7 @@ from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from typing_extensions import Annotated
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.backend.model_manager import ModelRepoVariant
from ..util import select_hf_files
@@ -68,12 +69,24 @@ class RemoteModelFile(BaseModel):
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
class ModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
steps: int | None
cfg_scale: float | None
cfg_rescale_multiplier: float | None
class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""
name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="default settings for this model", default=None
)
class BaseMetadata(ModelMetadataBase):

View File
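
A minimal illustration of the new `ModelDefaultSettings` shape added above, using a simplified stand-in (the real class constrains `scheduler` to the `SCHEDULER_NAME_VALUES` literal imported at the top of the file):

```py
from typing import Optional

from pydantic import BaseModel


class DefaultSettings(BaseModel):
    """Stand-in for ModelDefaultSettings: every field is nullable."""
    vae: Optional[str] = None
    vae_precision: Optional[str] = None
    scheduler: Optional[str] = None       # real class restricts this to known scheduler names
    steps: Optional[int] = None
    cfg_scale: Optional[float] = None
    cfg_rescale_multiplier: Optional[float] = None


settings = DefaultSettings(scheduler="euler", steps=30, cfg_scale=7.5)
print(settings.model_dump(exclude_none=True))
# {'scheduler': 'euler', 'steps': 30, 'cfg_scale': 7.5}
```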

@@ -21,7 +21,7 @@ from .config import (
ModelVariantType,
SchedulerPredictionType,
)
from .hash import FastModelHash
from .hash import ModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
CkptType = Dict[str, Any]
@@ -147,7 +147,7 @@ class ModelProbe(object):
if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = FastModelHash.hash(model_path)
hash = ModelHash().hash(model_path)
probe = probe_class(model_path)
fields["path"] = model_path.as_posix()

View File

@@ -70,7 +70,12 @@ class Range:
class TextConditioningRegions:
def __init__(self, masks: torch.Tensor, ranges: list[Range]):
def __init__(
self,
masks: torch.Tensor,
ranges: list[Range],
mask_weights: list[float],
):
# A binary mask indicating the regions of the image that the prompt should be applied to.
# Shape: (1, num_prompts, height, width)
# Dtype: torch.bool
@@ -80,7 +85,9 @@ class TextConditioningRegions:
# ranges[i] contains the embedding range for the i'th prompt / mask.
self.ranges = ranges
assert self.masks.shape[1] == len(self.ranges)
self.mask_weights = mask_weights
assert self.masks.shape[1] == len(self.ranges) == len(self.mask_weights)
class TextConditioningData:

View File
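
The assertion above ties the three constructor arguments together: one mask channel, one embedding range, and one mask weight per prompt. A tiny sketch constructing a two-prompt region set with dummy tensors (shapes follow the comments in the diff; `Range` is stood in by a namedtuple):

```py
from collections import namedtuple

import torch

# Two prompts over a 64x64 latent grid.
num_prompts, h, w = 2, 64, 64
masks = torch.zeros((1, num_prompts, h, w), dtype=torch.bool)
masks[0, 0, :, : w // 2] = True   # prompt 0 owns the left half
masks[0, 1, :, w // 2 :] = True   # prompt 1 owns the right half

# ranges[i] is the slice of the concatenated text embedding belonging to prompt i.
Range = namedtuple("Range", ["start", "end"])
ranges = [Range(0, 77), Range(77, 154)]

mask_weights = [0.5, 1.0]  # per-prompt weights, the new argument added in this diff

assert masks.shape[1] == len(ranges) == len(mask_weights)
```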

@@ -1,3 +1,4 @@
import math
from typing import Optional
import torch
@@ -58,6 +59,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
scale: float = 1.0,
# For regional prompting:
regional_prompt_data: Optional[RegionalPromptData] = None,
percent_through: Optional[float] = None,
# For IP-Adapter:
ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.FloatTensor:
@@ -92,28 +94,25 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# End unmodified block from AttnProcessor2_0.
# Handle regional prompt attention masks.
if is_cross_attention and regional_prompt_data is not None:
if regional_prompt_data is not None:
assert percent_through is not None
_, query_seq_len, _ = hidden_states.shape
prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len)
# TODO(ryand): Avoid redundant type/device conversion here.
prompt_region_attention_mask = prompt_region_attention_mask.to(
dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device
)
prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
if is_cross_attention:
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
query_seq_len=query_seq_len, key_seq_len=sequence_length
)
# TODO(ryand): Avoid redundant type/device conversion here.
prompt_region_attention_mask = prompt_region_attention_mask.to(
dtype=hidden_states.dtype, device=hidden_states.device
)
if attention_mask is None:
attention_mask = prompt_region_attention_mask
else:
attention_mask = prompt_region_attention_mask + attention_mask
# Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
attn_mask_weight = 1.0 * ((1 - percent_through) ** 5)
else: # self-attention
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
query_seq_len=query_seq_len,
percent_through=percent_through,
)
attn_mask_weight = 0.3 * ((1 - percent_through) ** 5)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -137,6 +136,40 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if regional_prompt_data is not None and percent_through < 0.3:
# Don't apply to uncond????
prompt_region_attention_mask = attn.prepare_attention_mask(
prompt_region_attention_mask, sequence_length, batch_size
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
prompt_region_attention_mask = prompt_region_attention_mask.view(
batch_size, attn.heads, -1, prompt_region_attention_mask.shape[-1]
)
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor
m_pos = attn_weight.max(dim=-1, keepdim=True)[0] - attn_weight
m_neg = attn_weight - attn_weight.min(dim=-1, keepdim=True)[0]
prompt_region_attention_mask = attn_mask_weight * (
m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
)
if attention_mask is None:
attention_mask = prompt_region_attention_mask
else:
attention_mask = prompt_region_attention_mask + attention_mask
else:
pass
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(

View File
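
The block above is the heart of the DenseDiffusion-style adjustment: inside a prompt's region the bias pushes attention scores toward their row maximum (`m_pos`), outside it toward the row minimum (`m_neg`), and the whole bias decays as denoising proceeds. A self-contained sketch of that computation on toy tensors (the 1.0 / 0.3 scales and the `(1 - percent_through) ** 5` decay are taken from the diff; everything else is illustrative):

```py
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, q_len, k_len, head_dim = 1, 2, 8, 8, 16
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, k_len, head_dim)
value = torch.randn(batch, heads, k_len, head_dim)

# Binary region mask: 1.0 where a query position should attend to a key position.
region_mask = torch.zeros(batch, heads, q_len, k_len)
region_mask[..., : k_len // 2] = 1.0

percent_through = 0.1                       # early in denoising
is_cross_attention = True
attn_mask_weight = (1.0 if is_cross_attention else 0.3) * ((1 - percent_through) ** 5)

scale = 1 / (head_dim ** 0.5)
attn_weight = query @ key.transpose(-2, -1) * scale

# Distance of each score from its row max / row min.
m_pos = attn_weight.max(dim=-1, keepdim=True)[0] - attn_weight
m_neg = attn_weight - attn_weight.min(dim=-1, keepdim=True)[0]

# Positive bias inside the region, negative bias outside, decaying over time.
mask_bias = attn_mask_weight * (m_pos * region_mask - m_neg * (1.0 - region_mask))

out = F.scaled_dot_product_attention(query, key, value, attn_mask=mask_bias)
print(out.shape)  # torch.Size([1, 2, 8, 16])
```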

@@ -7,58 +7,56 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
class RegionalPromptData:
def __init__(self, attn_masks_by_seq_len: dict[int, torch.Tensor]):
self._attn_masks_by_seq_len = attn_masks_by_seq_len
@classmethod
def from_regions(
cls,
def __init__(
self,
regions: list[TextConditioningRegions],
key_seq_len: int,
# TODO(ryand): Pass in a list of downscale factors?
device: torch.device,
dtype: torch.dtype,
max_downscale_factor: int = 8,
):
"""Construct a `RegionalPromptData` object.
"""Initialize a `RegionalPromptData` object.
Args:
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
batch.
key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
explicitly to be sure.
device (torch.device): The device to use for the attention masks.
dtype (torch.dtype): The data type to use for the attention masks.
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
in steps of 2x.
"""
attn_masks_by_seq_len = {}
self._regions = regions
self._device = device
self._dtype = dtype
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
# sequence length of s.
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
regions, max_downscale_factor
)
self._negative_cross_attn_mask_score = 0.0
self._size_weight = 1.0
def _prepare_spatial_masks(
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
) -> list[dict[int, torch.Tensor]]:
"""Prepare the spatial masks for all downscaling factors."""
# batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
# of s.
batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
# batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
# length of s.
batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
for batch_sample_regions in regions:
batch_attn_masks_by_seq_len.append({})
batch_sample_masks_by_seq_len.append({})
# Convert the bool masks to float masks so that max pooling can be applied.
batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
downscale_factor = 1
while downscale_factor <= max_downscale_factor:
_, num_prompts, h, w = batch_masks.shape
b, _num_prompts, h, w = batch_sample_masks.shape
assert b == 1
query_seq_len = h * w
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
# Create a cross-attention mask for each prompt that selects the corresponding embeddings from
# `encoder_hidden_states`.
# attn_mask shape: (batch_size, query_seq_len, key_seq_len)
# TODO(ryand): What device / dtype should this be?
attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
:, prompt_idx, :, :
]
batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks
downscale_factor *= 2
if downscale_factor <= max_downscale_factor:
@@ -66,23 +64,17 @@ class RegionalPromptData:
# regions to be lost entirely.
# TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
# potentially use a weighted mask rather than a binary mask.
batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2)
# Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
attn_masks_by_seq_len[query_seq_len] = torch.cat(
[batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
)
return batch_sample_masks_by_seq_len
return cls(attn_masks_by_seq_len)
def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
"""Get the cross-attention mask for the given query sequence length.
def get_attn_mask(self, query_seq_len: int) -> torch.Tensor:
"""Get the attention mask for the given query sequence length (i.e. downscaling level).
This is called during cross-attention, where query_seq_len is the length of the flattened spatial features, so
it changes at each downscaling level in the model.
key_seq_len is the length of the expected prompt embeddings.
Args:
query_seq_len: The length of the flattened spatial features at the current downscaling level.
key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention
layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure.
Returns:
torch.Tensor: The masks.
@@ -90,4 +82,83 @@ class RegionalPromptData:
dtype: float
The mask is a binary mask with values of 0.0 and 1.0.
"""
return self._attn_masks_by_seq_len[query_seq_len]
batch_size = len(self._spatial_masks_by_seq_len)
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
# Create an empty attention mask with the correct shape.
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)
for batch_idx in range(batch_size):
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
batch_sample_regions = self._regions[batch_idx]
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :]
size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel()
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
# size = size.to(dtype=batch_sample_query_scores.dtype)
# batch_sample_query_mask = batch_sample_query_scores > 0.5
# batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
# batch_sample_query_scores[~batch_sample_query_mask] = 0.0
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores * (
mask_weight + self._size_weight * (1 - size)
)
return attn_mask
def get_self_attn_mask(self, query_seq_len: int, percent_through: float) -> torch.Tensor:
"""Get the self-attention mask for the given query sequence length.
Args:
query_seq_len: The length of the flattened spatial features at the current downscaling level.
Returns:
torch.Tensor: The masks.
shape: (batch_size, query_seq_len, query_seq_len).
dtype: float
The mask is a binary mask with values of 0.0 and 1.0.
"""
batch_size = len(self._spatial_masks_by_seq_len)
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
# Create an empty attention mask with the correct shape.
attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=self._dtype, device=self._device)
for batch_idx in range(batch_size):
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
batch_sample_regions = self._regions[batch_idx]
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx in range(num_prompts):
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
size = prompt_query_mask.sum() / prompt_query_mask.numel()
size = size.to(dtype=prompt_query_mask.dtype)
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
# query_seq_len) mask.
# TODO(ryand): Is += really the best option here? Maybe elementwise max is better?
attn_mask[batch_idx, :, :] = torch.maximum(
attn_mask[batch_idx, :, :],
prompt_query_mask.unsqueeze(0)
* prompt_query_mask.unsqueeze(1)
* (mask_weight + self._size_weight * (1 - size)),
)
# if attn_mask[batch_idx].max() < 0.01:
# attn_mask[batch_idx, ...] = 1.0
# attn_mask[attn_mask > 0.5] = 1.0
# attn_mask[attn_mask <= 0.5] = 0.0
# attn_mask_min = attn_mask[batch_idx].min()
# # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
# if abs(attn_mask_min) > 0.0001:
# attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min
return attn_mask

View File
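
Both mask getters above scale a prompt's region by `mask_weight + size_weight * (1 - size)`, so small regions receive a stronger boost than large ones. A toy illustration of just that weighting (the values are placeholders; `size_weight` mirrors `self._size_weight = 1.0` set in the constructor):

```py
import torch

query_seq_len = 16
prompt_query_mask = torch.zeros(query_seq_len)
prompt_query_mask[:4] = 1.0                     # region covers 25% of the spatial positions

size = prompt_query_mask.sum() / prompt_query_mask.numel()   # 0.25
mask_weight = 0.2                                             # per-prompt weight from TextConditioningRegions
size_weight = 1.0                                             # self._size_weight in the diff

scaled = prompt_query_mask * (mask_weight + size_weight * (1 - size))
print(float(scaled.max()))   # 0.95 -- smaller regions end up closer to mask_weight + size_weight
```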

@@ -1,6 +1,7 @@
from __future__ import annotations
import math
import time
from contextlib import contextmanager
from typing import Any, Callable, Optional, Union
@@ -200,9 +201,9 @@ class InvokeAIDiffuserComponent:
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
):
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = []
if self.cross_attention_control_context is not None:
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = (
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
)
@@ -219,6 +220,7 @@ class InvokeAIDiffuserComponent:
sigma=timestep,
conditioning_data=conditioning_data,
ip_adapter_conditioning=ip_adapter_conditioning,
percent_through=percent_through,
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
@@ -232,6 +234,7 @@ class InvokeAIDiffuserComponent:
x=sample,
sigma=timestep,
conditioning_data=conditioning_data,
percent_through=percent_through,
ip_adapter_conditioning=ip_adapter_conditioning,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
@@ -293,6 +296,7 @@ class InvokeAIDiffuserComponent:
sigma,
conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
percent_through: float,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@@ -326,8 +330,8 @@ class InvokeAIDiffuserComponent:
)
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently call from_regions(...) for every denoising step. The text conditionings and
# masks are not changing from step-to-step, so this really only needs to be done once. While this seems
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
@@ -342,13 +346,15 @@ class InvokeAIDiffuserComponent:
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=torch.bool),
ranges=[Range(start=0, end=c.embeds.shape[1])],
mask_weights=[0.0],
)
regions.append(r)
_, key_seq_len, _ = both_conditionings.shape
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
regions=regions, key_seq_len=key_seq_len
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=regions, device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through
time.sleep(1.0)
both_results = self.model_forward_callback(
x_twice,
@@ -371,6 +377,7 @@ class InvokeAIDiffuserComponent:
conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
cross_attention_control_types_to_do: list[CrossAttentionType],
percent_through: float,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@@ -441,10 +448,10 @@ class InvokeAIDiffuserComponent:
# Prepare prompt regions for the unconditioned pass.
if conditioning_data.uncond_regions is not None:
_, key_seq_len, _ = conditioning_data.uncond_text.embeds.shape
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through
# Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback(
@@ -487,10 +494,10 @@ class InvokeAIDiffuserComponent:
# Prepare prompt regions for the conditioned pass.
if conditioning_data.cond_regions is not None:
_, key_seq_len, _ = conditioning_data.cond_text.embeds.shape
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through
# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback(

View File
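
Threading `percent_through = step_index / total_step_count` down into the attention processor matters because the regional bias above decays with it and, per the processor's guard, is only applied while `percent_through < 0.3`. A quick look at the effective weights over a 30-step run (a sketch of the values implied by the diff, not project code):

```py
total_step_count = 30
for step_index in (0, 3, 6, 9, 12):
    percent_through = step_index / total_step_count
    apply_bias = percent_through < 0.3          # guard from the attention processor
    cross_w = 1.0 * (1 - percent_through) ** 5 if apply_bias else 0.0
    self_w = 0.3 * (1 - percent_through) ** 5 if apply_bias else 0.0
    print(f"step {step_index:2d}: cross={cross_w:.3f} self={self_w:.3f}")
```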

@@ -134,8 +134,6 @@
"loadMore": "Mehr laden",
"noImagesInGallery": "Keine Bilder in der Galerie",
"loading": "Lade",
"preparingDownload": "bereite Download vor",
"preparingDownloadFailed": "Problem beim Download vorbereiten",
"deleteImage": "Lösche Bild",
"copy": "Kopieren",
"download": "Runterladen",
@@ -967,7 +965,7 @@
"resumeFailed": "Problem beim Fortsetzen des Prozesses",
"pruneFailed": "Problem beim leeren der Warteschlange",
"pauseTooltip": "Prozess anhalten",
"back": "Hinten",
"back": "Ende",
"resumeSucceeded": "Prozess wird fortgesetzt",
"resumeTooltip": "Prozess wieder aufnehmen",
"time": "Zeit",

View File

@@ -78,6 +78,7 @@
"aboutDesc": "Using Invoke for work? Check out:",
"aboutHeading": "Own Your Creative Power",
"accept": "Accept",
"add": "Add",
"advanced": "Advanced",
"advancedOptions": "Advanced Options",
"ai": "ai",
@@ -734,6 +735,8 @@
"customConfig": "Custom Config",
"customConfigFileLocation": "Custom Config File Location",
"customSaveLocation": "Custom Save Location",
"defaultSettings": "Default Settings",
"defaultSettingsSaved": "Default Settings Saved",
"delete": "Delete",
"deleteConfig": "Delete Config",
"deleteModel": "Delete Model",
@@ -768,6 +771,7 @@
"mergedModelName": "Merged Model Name",
"mergedModelSaveLocation": "Save Location",
"mergeModels": "Merge Models",
"metadata": "Metadata",
"model": "Model",
"modelAdded": "Model Added",
"modelConversionFailed": "Model Conversion Failed",
@@ -839,9 +843,12 @@
"statusConverting": "Converting",
"syncModels": "Sync Models",
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
"triggerPhrases": "Trigger Phrases",
"typePhraseHere": "Type phrase here",
"upcastAttention": "Upcast Attention",
"updateModel": "Update Model",
"useCustomConfig": "Use Custom Config",
"useDefaultSettings": "Use Default Settings",
"v1": "v1",
"v2_768": "v2 (768px)",
"v2_base": "v2 (512px)",
@@ -860,6 +867,7 @@
"models": {
"addLora": "Add LoRA",
"allLoRAsAdded": "All LoRAs added",
"concepts": "Concepts",
"loraAlreadyAdded": "LoRA already added",
"esrganModel": "ESRGAN Model",
"loading": "loading",

View File

@@ -505,8 +505,6 @@
"seamLowThreshold": "Bajo",
"coherencePassHeader": "Parámetros de la coherencia",
"compositingSettingsHeader": "Ajustes de la composición",
"coherenceSteps": "Pasos",
"coherenceStrength": "Fuerza",
"patchmatchDownScaleSize": "Reducir a escala",
"coherenceMode": "Modo"
},

View File

@@ -114,7 +114,8 @@
"checkpoint": "Checkpoint",
"safetensors": "Safetensors",
"ai": "ia",
"file": "File"
"file": "File",
"toResolve": "Da risolvere"
},
"gallery": {
"generations": "Generazioni",
@@ -142,8 +143,6 @@
"copy": "Copia",
"download": "Scarica",
"setCurrentImage": "Imposta come immagine corrente",
"preparingDownload": "Preparazione del download",
"preparingDownloadFailed": "Problema durante la preparazione del download",
"downloadSelection": "Scarica gli elementi selezionati",
"noImageSelected": "Nessuna immagine selezionata",
"deleteSelection": "Elimina la selezione",
@@ -609,8 +608,6 @@
"seamLowThreshold": "Basso",
"seamHighThreshold": "Alto",
"coherencePassHeader": "Passaggio di coerenza",
"coherenceSteps": "Passi",
"coherenceStrength": "Forza",
"compositingSettingsHeader": "Impostazioni di composizione",
"patchmatchDownScaleSize": "Ridimensiona",
"coherenceMode": "Modalità",
@@ -1400,19 +1397,6 @@
"Regola la maschera."
]
},
"compositingCoherenceSteps": {
"heading": "Passi",
"paragraphs": [
"Numero di passi utilizzati nel Passaggio di Coerenza.",
"Simile ai passi di generazione."
]
},
"compositingBlur": {
"heading": "Sfocatura",
"paragraphs": [
"Il raggio di sfocatura della maschera."
]
},
"compositingCoherenceMode": {
"heading": "Modalità",
"paragraphs": [
@@ -1431,13 +1415,6 @@
"Un secondo ciclo di riduzione del rumore aiuta a comporre l'immagine Inpaint/Outpaint."
]
},
"compositingStrength": {
"heading": "Forza",
"paragraphs": [
"Quantità di rumore aggiunta per il Passaggio di Coerenza.",
"Simile alla forza di riduzione del rumore."
]
},
"paramNegativeConditioning": {
"paragraphs": [
"Il processo di generazione evita i concetti nel prompt negativo. Utilizzatelo per escludere qualità o oggetti dall'output.",

View File

@@ -123,8 +123,6 @@
"autoSwitchNewImages": "새로운 이미지로 자동 전환",
"loading": "불러오는 중",
"unableToLoad": "갤러리를 로드할 수 없음",
"preparingDownload": "다운로드 준비",
"preparingDownloadFailed": "다운로드 준비 중 발생한 문제",
"singleColumnLayout": "단일 열 레이아웃",
"image": "이미지",
"loadMore": "더 불러오기",

View File

@@ -97,8 +97,6 @@
"featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.",
"loading": "Bezig met laden",
"unableToLoad": "Kan galerij niet laden",
"preparingDownload": "Bezig met voorbereiden van download",
"preparingDownloadFailed": "Fout bij voorbereiden van download",
"downloadSelection": "Download selectie",
"currentlyInUse": "Deze afbeelding is momenteel in gebruik door de volgende functies:",
"copy": "Kopieer",
@@ -535,8 +533,6 @@
"coherencePassHeader": "Coherentiestap",
"maskBlur": "Vervaag",
"maskBlurMethod": "Vervagingsmethode",
"coherenceSteps": "Stappen",
"coherenceStrength": "Sterkte",
"seamHighThreshold": "Hoog",
"seamLowThreshold": "Laag",
"invoke": {
@@ -1139,13 +1135,6 @@
"Een afbeeldingsgrootte (in aantal pixels) equivalent aan 512x512 wordt aanbevolen voor SD1.5-modellen. Een grootte-equivalent van 1024x1024 wordt aanbevolen voor SDXL-modellen."
]
},
"compositingCoherenceSteps": {
"heading": "Stappen",
"paragraphs": [
"Het aantal te gebruiken ontruisingsstappen in de coherentiefase.",
"Gelijk aan de hoofdparameter Stappen."
]
},
"dynamicPrompts": {
"paragraphs": [
"Dynamische prompts vormt een enkele prompt om in vele.",
@@ -1160,12 +1149,6 @@
],
"heading": "VAE"
},
"compositingBlur": {
"heading": "Vervaging",
"paragraphs": [
"De vervagingsstraal van het masker."
]
},
"paramIterations": {
"paragraphs": [
"Het aantal te genereren afbeeldingen.",
@@ -1240,13 +1223,6 @@
],
"heading": "Ontruisingssterkte"
},
"compositingStrength": {
"heading": "Sterkte",
"paragraphs": [
"Ontruisingssterkte voor de coherentiefase.",
"Gelijk aan de parameter Ontruisingssterkte Afbeelding naar afbeelding."
]
},
"paramNegativeConditioning": {
"paragraphs": [
"Het genereerproces voorkomt de gegeven begrippen in de negatieve prompt. Gebruik dit om bepaalde zaken of voorwerpen uit te sluiten van de uitvoerafbeelding.",

View File

@@ -143,8 +143,6 @@
"problemDeletingImagesDesc": "Не удалось удалить одно или несколько изображений",
"loading": "Загрузка",
"unableToLoad": "Невозможно загрузить галерею",
"preparingDownload": "Подготовка к скачиванию",
"preparingDownloadFailed": "Проблема с подготовкой к скачиванию",
"image": "изображение",
"drop": "перебросить",
"problemDeletingImages": "Проблема с удалением изображений",
@@ -612,9 +610,7 @@
"maskBlurMethod": "Метод размытия",
"seamLowThreshold": "Низкий",
"seamHighThreshold": "Высокий",
"coherenceSteps": "Шагов",
"coherencePassHeader": "Порог Coherence",
"coherenceStrength": "Сила",
"compositingSettingsHeader": "Настройки компоновки",
"invoke": {
"noNodesInGraph": "Нет узлов в графе",
@@ -1321,13 +1317,6 @@
"Размер изображения (в пикселях), эквивалентный 512x512, рекомендуется для моделей SD1.5, а размер, эквивалентный 1024x1024, рекомендуется для моделей SDXL."
]
},
"compositingCoherenceSteps": {
"heading": "Шаги",
"paragraphs": [
"Количество шагов снижения шума, используемых при прохождении когерентности.",
"То же, что и основной параметр «Шаги»."
]
},
"dynamicPrompts": {
"paragraphs": [
"Динамические запросы превращают одно приглашение на множество.",
@@ -1342,12 +1331,6 @@
],
"heading": "VAE"
},
"compositingBlur": {
"heading": "Размытие",
"paragraphs": [
"Радиус размытия маски."
]
},
"paramIterations": {
"paragraphs": [
"Количество изображений, которые нужно сгенерировать.",
@@ -1422,13 +1405,6 @@
],
"heading": "Шумоподавление"
},
"compositingStrength": {
"heading": "Сила",
"paragraphs": [
null,
"То же, что параметр «Сила шумоподавления img2img»."
]
},
"paramNegativeConditioning": {
"paragraphs": [
"Stable Diffusion пытается избежать указанных в отрицательном запросе концепций. Используйте это, чтобы исключить качества или объекты из вывода.",

View File

@@ -355,7 +355,6 @@
"starImage": "Yıldız Koy",
"download": "İndir",
"deleteSelection": "Seçileni Sil",
"preparingDownloadFailed": "İndirme Hazırlanırken Sorun",
"problemDeletingImages": "Görsel Silmede Sorun",
"featuresWillReset": "Bu görseli silerseniz, o özellikler resetlenecektir.",
"galleryImageResetSize": "Boyutu Resetle",
@@ -377,7 +376,6 @@
"setCurrentImage": "Çalışma Görseli Yap",
"unableToLoad": "Galeri Yüklenemedi",
"downloadSelection": "Seçileni İndir",
"preparingDownload": "İndirmeye Hazırlanıyor",
"singleColumnLayout": "Tek Sütun Düzen",
"generations": ıktılar",
"showUploads": "Yüklenenleri Göster",
@@ -723,7 +721,6 @@
"clipSkip": "CLIP Atlama",
"randomizeSeed": "Rastgele Tohum",
"cfgScale": "CFG Ölçeği",
"coherenceStrength": "Etki",
"controlNetControlMode": "Yönetim Kipi",
"general": "Genel",
"img2imgStrength": "Görselden Görsel Ölçüsü",
@@ -793,7 +790,6 @@
"cfgRescaleMultiplier": "CFG Rescale Çarpanı",
"cfgRescale": "CFG Rescale",
"coherencePassHeader": "Uyum Geçişi",
"coherenceSteps": "Adım",
"infillMethod": "Doldurma Yöntemi",
"maskBlurMethod": "Bulandırma Yöntemi",
"steps": "Adım",

View File

@@ -136,8 +136,6 @@
"copy": "复制",
"download": "下载",
"setCurrentImage": "设为当前图像",
"preparingDownload": "准备下载",
"preparingDownloadFailed": "准备下载时出现问题",
"downloadSelection": "下载所选内容",
"noImageSelected": "无选中的图像",
"deleteSelection": "删除所选内容",
@@ -616,11 +614,9 @@
"incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。"
},
"patchmatchDownScaleSize": "缩小",
"coherenceSteps": "步数",
"clipSkip": "CLIP 跳过层",
"compositingSettingsHeader": "合成设置",
"useCpuNoise": "使用 CPU 噪声",
"coherenceStrength": "强度",
"enableNoiseSettings": "启用噪声设置",
"coherenceMode": "模式",
"cpuNoise": "CPU 噪声",
@@ -1402,19 +1398,6 @@
"图像尺寸(单位:像素)建议 SD 1.5 模型使用等效 512x512 的尺寸SDXL 模型使用等效 1024x1024 的尺寸。"
]
},
"compositingCoherenceSteps": {
"heading": "步数",
"paragraphs": [
"一致性层中使用的去噪步数。",
"与主参数中的步数相同。"
]
},
"compositingBlur": {
"heading": "模糊",
"paragraphs": [
"遮罩模糊半径。"
]
},
"noiseUseCPU": {
"heading": "使用 CPU 噪声",
"paragraphs": [
@@ -1467,13 +1450,6 @@
"第二轮去噪有助于合成内补/外扩图像。"
]
},
"compositingStrength": {
"heading": "强度",
"paragraphs": [
"一致性层使用的去噪强度。",
"去噪强度与图生图的参数相同。"
]
},
"paramNegativeConditioning": {
"paragraphs": [
"生成过程会避免生成负向提示词中的概念。使用此选项来使输出排除部分质量或对象。",

View File

@@ -55,6 +55,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@@ -153,3 +155,5 @@ addUpscaleRequestedListener(startAppListening);
// Dynamic prompts
addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@@ -0,0 +1,96 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
setCfgRescaleMultiplier,
setCfgScale,
setScheduler,
setSteps,
vaePrecisionChanged,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
isParameterPrecision,
isParameterScheduler,
isParameterSteps,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: setDefaultSettings,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const currentModel = state.generation.model;
if (!currentModel) {
return;
}
const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap();
if (!metadata || !metadata.default_settings) {
return;
}
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings;
if (vae) {
// we store this as "default" within default settings
// to distinguish it from no default set
if (vae === 'default') {
dispatch(vaeSelected(null));
} else {
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
const vaeArray = map(data?.entities);
const validVae = vaeArray.find((model) => model.key === vae);
const result = zParameterVAEModel.safeParse(validVae);
if (!result.success) {
return;
}
dispatch(vaeSelected(result.data));
}
}
if (vae_precision) {
if (isParameterPrecision(vae_precision)) {
dispatch(vaePrecisionChanged(vae_precision));
}
}
if (cfg_scale) {
if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
}
if (cfg_rescale_multiplier) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
}
if (steps) {
if (isParameterSteps(steps)) {
dispatch(setSteps(steps));
}
}
if (scheduler) {
if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
}
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
},
});
};

View File

@@ -1,4 +1,5 @@
import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { O } from 'ts-toolbelt';
@@ -82,6 +83,8 @@ export type AppConfig = {
guidance: NumericalParameterConfig;
cfgRescaleMultiplier: NumericalParameterConfig;
img2imgStrength: NumericalParameterConfig;
scheduler?: ParameterScheduler;
vaePrecision?: ParameterPrecision;
// Canvas
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model

View File

@@ -59,7 +59,7 @@ const LoRASelect = () => {
return (
<FormControl isDisabled={!options.length}>
<InformationalPopover feature="lora">
<FormLabel>{t('models.lora')} </FormLabel>
<FormLabel>{t('models.concepts')} </FormLabel>
</InformationalPopover>
<Combobox
placeholder={placeholder}

View File

@@ -15,7 +15,7 @@ const STATUSES = {
const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
const { t } = useTranslation();
if (!status) {
if (!status || !Object.keys(STATUSES).includes(status)) {
return <></>;
}

View File

@@ -8,7 +8,7 @@ export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return (
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model /> : <ImportModels />}
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
</Box>
);
};

View File

@@ -0,0 +1,66 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import Loading from 'common/components/Loading/Loading';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { isNil } from 'lodash-es';
import { useMemo } from 'react';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd;
return {
initialSteps: steps.initial,
initialCfg: guidance.initial,
initialScheduler: scheduler,
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
initialVaePrecision: vaePrecision,
};
});
export const DefaultSettings = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
useAppSelector(initialStatesSelector);
const defaultSettingsDefaults = useMemo(() => {
return {
vae: { isEnabled: !isNil(data?.default_settings?.vae), value: data?.default_settings?.vae || 'default' },
vaePrecision: {
isEnabled: !isNil(data?.default_settings?.vae_precision),
value: data?.default_settings?.vae_precision || initialVaePrecision || 'fp32',
},
scheduler: {
isEnabled: !isNil(data?.default_settings?.scheduler),
value: data?.default_settings?.scheduler || initialScheduler || 'euler',
},
steps: { isEnabled: !isNil(data?.default_settings?.steps), value: data?.default_settings?.steps || initialSteps },
cfgScale: {
isEnabled: !isNil(data?.default_settings?.cfg_scale),
value: data?.default_settings?.cfg_scale || initialCfg,
},
cfgRescaleMultiplier: {
isEnabled: !isNil(data?.default_settings?.cfg_rescale_multiplier),
value: data?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
},
};
}, [
data?.default_settings,
initialSteps,
initialCfg,
initialScheduler,
initialCfgRescaleMultiplier,
initialVaePrecision,
]);
if (isLoading) {
return <Loading />;
}
return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;
};

View File

@@ -0,0 +1,72 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultCfgRescaleMultiplierType = DefaultSettingsFormData['cfgRescaleMultiplier'];
export function DefaultCfgRescaleMultiplier(props: UseControllerProps<DefaultSettingsFormData>) {
const { field } = useController(props);
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultCfgRescaleMultiplierType),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return (field.value as DefaultCfgRescaleMultiplierType).value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultCfgRescaleMultiplierType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramCFGRescaleMultiplier">
<FormLabel>{t('parameters.cfgRescaleMultiplier')}</FormLabel>
</InformationalPopover>
<Flex w="full" gap={1}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}

View File

@@ -0,0 +1,72 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultCfgType = DefaultSettingsFormData['cfgScale'];
export function DefaultCfgScale(props: UseControllerProps<DefaultSettingsFormData>) {
const { field } = useController(props);
const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.guidance.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.guidance.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.guidance.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.guidance.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.guidance.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultCfgType),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return (field.value as DefaultCfgType).value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultCfgType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramCFGScale">
<FormLabel>{t('parameters.cfgScale')}</FormLabel>
</InformationalPopover>
<Flex w="full" gap={1}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}

View File

@@ -0,0 +1,50 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultSchedulerType = DefaultSettingsFormData['scheduler'];
export function DefaultScheduler(props: UseControllerProps<DefaultSettingsFormData>) {
const { t } = useTranslation();
const { field } = useController(props);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isParameterScheduler(v?.value)) {
return;
}
const updatedValue = {
...(field.value as DefaultSchedulerType),
value: v.value,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(
() => SCHEDULER_OPTIONS.find((o) => o.value === (field.value as DefaultSchedulerType).value),
[field]
);
const isDisabled = useMemo(() => {
return !(field.value as DefaultSchedulerType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramScheduler">
<FormLabel>{t('parameters.scheduler')}</FormLabel>
</InformationalPopover>
<Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
</FormControl>
);
}

View File

@@ -0,0 +1,147 @@
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5';
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
import { DefaultCfgScale } from './DefaultCfgScale';
import { DefaultScheduler } from './DefaultScheduler';
import { DefaultSteps } from './DefaultSteps';
import { DefaultVae } from './DefaultVae';
import { DefaultVaePrecision } from './DefaultVaePrecision';
import { SettingToggle } from './SettingToggle';
export interface FormField<T> {
value: T;
isEnabled: boolean;
}
export type DefaultSettingsFormData = {
vae: FormField<string>;
vaePrecision: FormField<string>;
scheduler: FormField<ParameterScheduler>;
steps: FormField<number>;
cfgScale: FormField<number>;
cfgRescaleMultiplier: FormField<number>;
};
export const DefaultSettingsForm = ({
defaultSettingsDefaults,
}: {
defaultSettingsDefaults: DefaultSettingsFormData;
}) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
defaultValues: defaultSettingsDefaults,
});
const onSubmit = useCallback<SubmitHandler<DefaultSettingsFormData>>(
(data) => {
if (!selectedModelKey) {
return;
}
const body = {
vae: data.vae.isEnabled ? data.vae.value : null,
vae_precision: data.vaePrecision.isEnabled ? data.vaePrecision.value : null,
cfg_scale: data.cfgScale.isEnabled ? data.cfgScale.value : null,
cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
steps: data.steps.isEnabled ? data.steps.value : null,
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
};
editModelMetadata({
key: selectedModelKey,
body: { default_settings: body },
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.defaultSettingsSaved'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
},
[selectedModelKey, dispatch, editModelMetadata, t]
);
return (
<>
<Flex gap="2" justifyContent="space-between" w="full" mb={5}>
<Heading fontSize="md">{t('modelManager.defaultSettings')}</Heading>
<Button
size="sm"
leftIcon={<IoPencil />}
colorScheme="invokeYellow"
isDisabled={!formState.isDirty}
onClick={handleSubmit(onSubmit)}
type="submit"
isLoading={isLoading}
>
{t('common.save')}
</Button>
</Flex>
<Flex flexDir="column" gap={8}>
<Flex gap={8}>
<Flex gap={4} w="full">
<SettingToggle control={control} name="vae" />
<DefaultVae control={control} name="vae" />
</Flex>
<Flex gap={4} w="full">
<SettingToggle control={control} name="vaePrecision" />
<DefaultVaePrecision control={control} name="vaePrecision" />
</Flex>
</Flex>
<Flex gap={8}>
<Flex gap={4} w="full">
<SettingToggle control={control} name="scheduler" />
<DefaultScheduler control={control} name="scheduler" />
</Flex>
<Flex gap={4} w="full">
<SettingToggle control={control} name="steps" />
<DefaultSteps control={control} name="steps" />
</Flex>
</Flex>
<Flex gap={8}>
<Flex gap={4} w="full">
<SettingToggle control={control} name="cfgScale" />
<DefaultCfgScale control={control} name="cfgScale" />
</Flex>
<Flex gap={4} w="full">
<SettingToggle control={control} name="cfgRescaleMultiplier" />
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
</Flex>
</Flex>
</Flex>
</>
);
};
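
On save, disabled fields are serialized as null and the whole object is nested under default_settings before being sent through useUpdateModelMetadataMutation, which issues a PATCH to the model metadata endpoint (see the modelsApi changes further down). A rough illustration of the request body, with invented values; the authoritative schema comes from the generated OpenAPI types, which are not shown here:

// Illustrative request body only - the values are made up for this example.
const exampleBody = {
  default_settings: {
    vae: 'default',
    vae_precision: 'fp16',
    cfg_scale: 7.5,
    cfg_rescale_multiplier: null, // toggle disabled, so the field is cleared
    steps: 30,
    scheduler: 'euler',
  },
};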


@@ -0,0 +1,72 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultSteps = DefaultSettingsFormData['steps'];
export function DefaultSteps(props: UseControllerProps<DefaultSettingsFormData>) {
const { field } = useController(props);
const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.steps.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.steps.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.steps.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.steps.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.steps.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultSteps),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return (field.value as DefaultSteps).value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultSteps).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramSteps">
<FormLabel>{t('parameters.steps')}</FormLabel>
</InformationalPopover>
<Flex w="full" gap={1}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}
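
The slider and the number input share the same value and onChange but take their bounds from separate config entries (sliderMin/sliderMax versus numberInputMin/numberInputMax), and the slider marks are the two ends plus the midpoint. For example, if sliderMin were 1 and sliderMax were 100 (hypothetical values, the actual bounds come from the app config):

// Math.floor(100 / 2) === 50, so:
const exampleMarks = [1, 50, 100];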


@@ -0,0 +1,65 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultVaeType = DefaultSettingsFormData['vae'];
export function DefaultVae(props: UseControllerProps<DefaultSettingsFormData>) {
const { t } = useTranslation();
const { field } = useController(props);
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
selectFromResult: ({ data }) => {
const modelArray = map(data?.entities);
const compatibleOptions = modelArray
.filter((vae) => vae.base === modelData?.base)
.map((vae) => ({ label: vae.name, value: vae.key }));
const defaultOption = { label: 'Default VAE', value: 'default' };
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
},
});
const onChange = useCallback<ComboboxOnChange>(
(v) => {
const newValue = !v?.value ? 'default' : v.value;
const updatedValue = {
...(field.value as DefaultVaeType),
value: newValue,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return compatibleOptions.find((vae) => vae.value === (field.value as DefaultVaeType).value);
}, [compatibleOptions, field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultVaeType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramVAE">
<FormLabel>{t('modelManager.vae')}</FormLabel>
</InformationalPopover>
<Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
</FormControl>
);
}
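
The combobox options are built in selectFromResult: only VAEs whose base matches the selected model are offered, prefixed with a synthetic 'Default VAE' entry whose value is the string 'default', which is also what clearing the combobox falls back to. The resulting options might look like this (the model name and key are invented for illustration):

// Example of the derived options array for an assumed SDXL main model.
const exampleOptions = [
  { label: 'Default VAE', value: 'default' },
  { label: 'sdxl-vae-fp16-fix', value: 'b2c3d4e5' }, // hypothetical VAE key
];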


@@ -0,0 +1,51 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
const options = [
{ label: 'FP16', value: 'fp16' },
{ label: 'FP32', value: 'fp32' },
];
type DefaultVaePrecisionType = DefaultSettingsFormData['vaePrecision'];
export function DefaultVaePrecision(props: UseControllerProps<DefaultSettingsFormData>) {
const { t } = useTranslation();
const { field } = useController(props);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isParameterPrecision(v?.value)) {
return;
}
const updatedValue = {
...(field.value as DefaultVaePrecisionType),
value: v.value,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => options.find((o) => o.value === (field.value as DefaultVaePrecisionType).value), [field]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultVaePrecisionType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramVAEPrecision">
<FormLabel>{t('modelManager.vaePrecision')}</FormLabel>
</InformationalPopover>
<Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
</FormControl>
);
}


@@ -0,0 +1,28 @@
import { Switch } from '@invoke-ai/ui-library';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { DefaultSettingsFormData, FormField } from './DefaultSettingsForm';
export function SettingToggle<T>(props: UseControllerProps<DefaultSettingsFormData>) {
const { field } = useController(props);
const value = useMemo(() => {
return !!(field.value as FormField<T>).isEnabled;
}, [field.value]);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
const updatedValue: FormField<T> = {
...(field.value as FormField<T>),
isEnabled: e.target.checked,
};
field.onChange(updatedValue);
},
[field]
);
return <Switch isChecked={value} onChange={onChange} />;
}
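
SettingToggle only flips isEnabled and leaves value untouched, so turning a setting off and back on restores whatever value was previously chosen. With invented numbers, the transitions look like:

// Before toggling off:  { value: 30, isEnabled: true }
// After toggling off:   { value: 30, isEnabled: false }  (value retained)
// After toggling on:    { value: 30, isEnabled: true }   (prior value restored)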


@@ -0,0 +1,18 @@
import { Flex } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
export const ModelMetadata = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
return (
<>
<Flex flexDir="column" height="full" gap="3">
<DataViewer label="metadata" data={metadata || {}} />
</Flex>
</>
);
};


@@ -1,9 +1,58 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { ModelMetadata } from './Metadata/ModelMetadata';
import { ModelAttrView } from './ModelAttrView';
import { ModelEdit } from './ModelEdit';
import { ModelView } from './ModelView';
export const Model = () => {
const { t } = useTranslation();
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
return selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />;
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
if (isLoading) {
return <Text>{t('common.loading')}</Text>;
}
if (!data) {
return <Text>{t('common.somethingWentWrong')}</Text>;
}
return (
<>
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>
{data.source && (
<Text variant="subtext">
{t('modelManager.source')}: {data?.source}
</Text>
)}
<Box mt="4">
<ModelAttrView label="Description" value={data.description} />
</Box>
</Flex>
<Tabs mt="4" h="100%">
<TabList>
<Tab>{t('modelManager.settings')}</Tab>
<Tab>{t('modelManager.metadata')}</Tab>
</TabList>
<TabPanels h="100%">
<TabPanel>{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />}</TabPanel>
<TabPanel h="full">
<ModelMetadata />
</TabPanel>
</TabPanels>
</Tabs>
</>
);
};


@@ -1,12 +1,11 @@
import { Box, Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5';
import { useGetModelConfigQuery, useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import type {
CheckpointModelConfig,
ControlNetModelConfig,
@@ -18,6 +17,7 @@ import type {
VAEModelConfig,
} from 'services/api/types';
import { DefaultSettings } from './DefaultSettings';
import { ModelAttrView } from './ModelAttrView';
import { ModelConvert } from './ModelConvert';
@@ -26,7 +26,6 @@ export const ModelView = () => {
const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const modelData = useMemo(() => {
if (!data) {
@@ -73,85 +72,56 @@ export const ModelView = () => {
return <Text>{t('common.somethingWentWrong')}</Text>;
}
return (
<Flex flexDir="column" h="full">
<Flex w="full" justifyContent="space-between">
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{modelData.name}
</Heading>
{modelData.source && (
<Text variant="subtext">
{t('modelManager.source')}: {modelData.source}
</Text>
)}
</Flex>
<Flex gap={2}>
<Flex flexDir="column" h="full" gap="2">
<Box layerStyle="second" borderRadius="base" p={3}>
<Flex gap="2" justifyContent="flex-end" w="full">
<Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
{t('modelManager.edit')}
</Button>
{modelData.type === 'main' && modelData.format === 'checkpoint' && <ModelConvert model={modelData} />}
</Flex>
</Flex>
<Flex flexDir="column" p={2} gap={3}>
<Flex>
<ModelAttrView label="Description" value={modelData.description} />
</Flex>
<Heading as="h3" fontSize="md" mt="4">
{t('modelManager.modelSettings')}
</Heading>
<Box layerStyle="second" borderRadius="base" p={3}>
<Flex flexDir="column" gap={3}>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('common.format')} value={modelData.format} />
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={2}>
{modelData.format === 'diffusers' && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
)}
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
</Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex flexDir="column" gap={3}>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('common.format')} value={modelData.format} />
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
</Flex>
)}
</Flex>
</Box>
</Flex>
{modelData.format === 'diffusers' && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
)}
{metadata && (
<>
<Heading as="h3" fontSize="md" mt="4">
{t('modelManager.modelMetadata')}
</Heading>
<Flex h="full" w="full" p={2}>
<DataViewer label="metadata" data={metadata} />
</Flex>
</>
)}
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
</Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
</Flex>
)}
</Flex>
</Box>
<Box layerStyle="second" borderRadius="base" p={3}>
<DefaultSettings />
</Box>
</Flex>
);
};


@@ -344,8 +344,8 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
node_id: INPAINT_CREATE_MASK,
field: 'expanded_mask_area',
},
destination: {
node_id: MASK_RESIZE_DOWN,


@@ -439,8 +439,8 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
node_id: INPAINT_CREATE_MASK,
field: 'expanded_mask_area',
},
destination: {
node_id: MASK_RESIZE_DOWN,


@@ -355,8 +355,8 @@ export const buildCanvasSDXLInpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
node_id: INPAINT_CREATE_MASK,
field: 'expanded_mask_area',
},
destination: {
node_id: MASK_RESIZE_DOWN,


@@ -448,8 +448,8 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
node_id: INPAINT_CREATE_MASK,
field: 'expanded_mask_area',
},
destination: {
node_id: MASK_RESIZE_DOWN,
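
The same one-edge change is applied in all four canvas graph builders (inpaint, outpaint, and their SDXL variants): the input of MASK_RESIZE_DOWN now comes from the create-mask node's expanded_mask_area output instead of MASK_RESIZE_UP's image output, so the downscaled mask is derived from the expanded mask that node produces. Schematically, with the node ids written as strings rather than the constants used above, and with the destination field omitted because it falls outside the hunks shown:

// Sketch of the rewired edge common to the four graphs.
const exampleEdge = {
  source: { node_id: 'INPAINT_CREATE_MASK', field: 'expanded_mask_area' },
  destination: { node_id: 'MASK_RESIZE_DOWN' }, // destination field unchanged, not shown in these hunks
};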


@@ -0,0 +1,36 @@
import type { IconButtonProps } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiGearSixBold } from 'react-icons/pi';
export const NavigateToModelManagerButton = memo((props: Omit<IconButtonProps, 'aria-label'>) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const shouldShowButton = useMemo(() => !disabledTabs.includes('modelManager'), [disabledTabs]);
const handleClick = useCallback(() => {
dispatch(setActiveTab('modelManager'));
}, [dispatch]);
if (!shouldShowButton) {
return null;
}
return (
<IconButton
icon={<PiGearSixBold />}
tooltip={t('modelManager.modelManager')}
aria-label={t('modelManager.modelManager')}
onClick={handleClick}
size="sm"
variant="ghost"
{...props}
/>
);
});
NavigateToModelManagerButton.displayName = 'NavigateToModelManagerButton';


@@ -0,0 +1,28 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setDefaultSettings } from 'features/parameters/store/actions';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { RiSparklingFill } from 'react-icons/ri';
export const UseDefaultSettingsButton = () => {
const model = useAppSelector((s) => s.generation.model);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const handleClickDefaultSettings = useCallback(() => {
dispatch(setDefaultSettings());
}, [dispatch]);
return (
<IconButton
icon={<RiSparklingFill />}
tooltip={t('modelManager.useDefaultSettings')}
aria-label={t('modelManager.useDefaultSettings')}
isDisabled={!model}
onClick={handleClickDefaultSettings}
size="sm"
variant="ghost"
/>
);
};


@@ -5,3 +5,5 @@ import type { ImageDTO } from 'services/api/types';
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
export const modelSelected = createAction<ParameterModel>('generation/modelSelected');
export const setDefaultSettings = createAction('generation/setDefaultSettings');
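
setDefaultSettings is a bare action creator with no reducer in the slices shown here, so it is presumably handled elsewhere, for example by a listener that looks up the selected model's default settings and writes them into the generation state. That handler is not part of this excerpt; the sketch below is only a hypothetical illustration of such wiring, not the actual implementation:

// Hypothetical listener sketch - the real handler is not shown in this diff.
import { createListenerMiddleware } from '@reduxjs/toolkit';
import { setDefaultSettings } from 'features/parameters/store/actions';

const listenerMiddleware = createListenerMiddleware();
listenerMiddleware.startListening({
  actionCreator: setDefaultSettings,
  effect: async (_action, listenerApi) => {
    // Fetch the current model's default_settings (e.g. via the models API) and
    // dispatch the corresponding generation-slice setters via listenerApi.dispatch.
  },
});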


@@ -230,6 +230,12 @@ export const generationSlice = createSlice({
state.height = optimalDimension;
}
}
if (action.payload.sd?.scheduler) {
state.scheduler = action.payload.sd.scheduler;
}
if (action.payload.sd?.vaePrecision) {
state.vaePrecision = action.payload.sd.vaePrecision;
}
});
// TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling
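
These additions let the host configuration seed the generation slice: when the incoming payload carries sd.scheduler or sd.vaePrecision, those values overwrite the corresponding state. Together with the defaults added to initialConfigState further down (scheduler: 'euler', vaePrecision: 'fp32'), the relevant part of such a payload might look like this sketch (surrounding fields omitted):

// Illustrative shape of the config payload fields consumed above.
const examplePayload = {
  sd: {
    scheduler: 'euler',
    vaePrecision: 'fp32',
    // ...other sd config, e.g. the steps slider bounds
  },
};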


@@ -1,15 +1,5 @@
import type { FormLabelProps } from '@invoke-ai/ui-library';
import {
Expander,
Flex,
FormControlGroup,
StandaloneAccordion,
Tab,
TabList,
TabPanel,
TabPanels,
Tabs,
} from '@invoke-ai/ui-library';
import { Box, Expander, Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
@@ -20,7 +10,9 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { filter } from 'lodash-es';
@@ -39,11 +31,11 @@ export const GenerationSettingsAccordion = memo(() => {
() =>
createMemoizedSelector(selectLoraSlice, (lora) => {
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : EMPTY_ARRAY;
const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
return { loraTabBadges, accordionBadges };
}),
[modelConfig]
[modelConfig, t]
);
const { loraTabBadges, accordionBadges } = useAppSelector(selectBadges);
const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({
@@ -58,39 +50,35 @@ export const GenerationSettingsAccordion = memo(() => {
return (
<StandaloneAccordion
label={t('accordions.generation.title')}
badges={accordionBadges}
badges={[...accordionBadges, ...loraTabBadges]}
isOpen={isOpenAccordion}
onToggle={onToggleAccordion}
>
<Tabs variant="collapse">
<TabList>
<Tab>{t('accordions.generation.modelTab')}</Tab>
<Tab badges={loraTabBadges}>{t('accordions.generation.conceptsTab')}</Tab>
</TabList>
<TabPanels>
<TabPanel overflow="visible" px={4} pt={4}>
<Flex gap={4} alignItems="center">
<ParamMainModelSelect />
<Box px={4} pt={4}>
<Flex gap={4} flexDir="column">
<Flex gap={4} alignItems="center">
<ParamMainModelSelect />
<Flex>
<UseDefaultSettingsButton />
<SyncModelsIconButton />
<NavigateToModelManagerButton />
</Flex>
<Expander isOpen={isOpenExpander} onToggle={onToggleExpander}>
<Flex gap={4} flexDir="column" pb={4}>
<FormControlGroup formLabelProps={formLabelProps}>
<ParamScheduler />
<ParamSteps />
<ParamCFGScale />
</FormControlGroup>
</Flex>
</Expander>
</TabPanel>
<TabPanel>
<Flex gap={4} p={4} flexDir="column">
<LoRASelect />
<LoRAList />
</Flex>
</TabPanel>
</TabPanels>
</Tabs>
</Flex>
<Flex gap={4} flexDir="column">
<LoRASelect />
<LoRAList />
</Flex>
</Flex>
<Expander isOpen={isOpenExpander} onToggle={onToggleExpander}>
<Flex gap={4} flexDir="column" pb={4}>
<FormControlGroup formLabelProps={formLabelProps}>
<ParamScheduler />
<ParamSteps />
<ParamCFGScale />
</FormControlGroup>
</Flex>
</Expander>
</Box>
</StandaloneAccordion>
);
});


@@ -41,6 +41,8 @@ const initialConfigState: AppConfig = {
boundingBoxHeight: { ...baseDimensionConfig },
scaledBoundingBoxWidth: { ...baseDimensionConfig },
scaledBoundingBoxHeight: { ...baseDimensionConfig },
scheduler: 'euler',
vaePrecision: 'fp32',
steps: {
initial: 30,
sliderMin: 1,


@@ -24,7 +24,15 @@ export type UpdateModelArg = {
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
};
type UpdateModelMetadataArg = {
key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
};
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelMetadataResponse =
paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelMetadataResponse =
@@ -172,6 +180,16 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
query: ({ key, body }) => {
return {
url: buildModelsUrl(`i/${key}/metadata`),
method: 'PATCH',
body: body,
};
},
invalidatesTags: ['Model'],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source, config, access_token }) => {
return {
@@ -351,6 +369,7 @@ export const {
useGetModelMetadataQuery,
useDeleteModelImportMutation,
usePruneModelImportsMutation,
useUpdateModelMetadataMutation,
} = modelsApi;
const upsertModelConfigs = (

File diff suppressed because one or more lines are too long


@@ -51,12 +51,12 @@ dependencies = [
"torchmetrics==0.11.4",
"torchsde==0.2.6",
"torchvision==0.16.2",
"transformers==4.37.2",
"transformers==4.38.2",
# Core application dependencies, pinned for reproducible builds.
"fastapi-events==0.10.1",
"fastapi==0.109.2",
"huggingface-hub==0.20.3",
"huggingface-hub==0.21.3",
"pydantic-settings==2.1.0",
"pydantic==2.6.1",
"python-socketio==5.11.1",
@@ -64,6 +64,7 @@ dependencies = [
# Auxiliary dependencies, pinned only if necessary.
"albumentations",
"blake3",
"click",
"datasets",
"Deprecated",
@@ -72,7 +73,6 @@ dependencies = [
"easing-functions",
"einops",
"facexlib",
"imohash",
"matplotlib", # needed for plotting of Penner easing functions
"npyscreen",
"omegaconf",


@@ -3,6 +3,7 @@ Test the model installer
"""
import platform
import uuid
from pathlib import Path
import pytest
@@ -30,9 +31,8 @@ def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Pa
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = mm2_installer.register_path(embedding_file)
assert key is not None
assert key != "<NOKEY>"
assert len(key) == 32
# Not raising here is sufficient - key should be UUIDv4
uuid.UUID(key, version=4)
def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:

tests/test_model_hash.py Normal file

@@ -0,0 +1,96 @@
# pyright:reportPrivateUsage=false
from pathlib import Path
from typing import Iterable
import pytest
from blake3 import blake3
from invokeai.backend.model_manager.hash import ALGORITHM, MODEL_FILE_EXTENSIONS, ModelHash
test_cases: list[tuple[ALGORITHM, str]] = [
("md5", "a0cd925fc063f98dbf029eee315060c3"),
("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"),
("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
(
"sha512",
"c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6",
),
("blake3", "ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
]
@pytest.mark.parametrize("algorithm,expected_hash", test_cases)
def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_hash: str):
file = Path(tmp_path / "test")
file.write_text("model data")
computed_hash = ModelHash(algorithm).hash(file)
assert computed_hash == expected_hash
@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"])
def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM):
model_hash = ModelHash(algorithm)
files = [Path(tmp_path, f"{i}.bin") for i in range(5)]
for f in files:
f.write_text("data")
computed_hash = model_hash.hash(tmp_path)
# Manual implementation of composite hash - always uses BLAKE3
composite_hasher = blake3()
for f in files:
h = model_hash.hash(f)
composite_hasher.update(h.encode("utf-8"))
assert computed_hash == composite_hasher.hexdigest()
def test_model_hash_raises_error_on_invalid_algorithm():
with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"):
ModelHash("invalid_algorithm") # pyright: ignore [reportArgumentType]
def paths_to_str_set(paths: Iterable[Path]) -> set[str]:
return {str(p) for p in paths}
def test_model_hash_filters_out_non_model_files(tmp_path: Path):
model_files = {Path(tmp_path, f"{i}{ext}") for i, ext in enumerate(MODEL_FILE_EXTENSIONS)}
for i, f in enumerate(model_files):
f.write_text(f"data{i}")
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
model_files
)
# Add a file that should be ignored - the filtered file set (and thus the hash) should not change
file = tmp_path / "test.icecream"
file.write_text("data")
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
model_files
)
# Add a file that should not be ignored - the filtered file set (and thus the hash) should change
file = tmp_path / "test.bin"
file.write_text("more data")
model_files.add(file)
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
model_files
)
def test_model_hash_uses_custom_filter(tmp_path: Path):
model_files = {Path(tmp_path, f"file{ext}") for ext in [".pickme", ".ignoreme"]}
for i, f in enumerate(model_files):
f.write_text(f"data{i}")
def file_filter(file_path: str) -> bool:
return file_path.endswith(".pickme")
assert {p.name for p in ModelHash._get_file_paths(tmp_path, file_filter)} == {"file.pickme"}
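
Taken together, the tests above pin down the hashing scheme: a single file is hashed with the selected algorithm, while a directory hash is built by hashing each model file individually and feeding the hex digests into a BLAKE3 hasher, with files whose extensions are not in MODEL_FILE_EXTENSIONS excluded by the default filter. A minimal TypeScript sketch of that combining idea follows; the backend is Python and uses BLAKE3 for the composite step, so sha256 stands in here, and the extension list is illustrative rather than the backend's actual list:

// Sketch only: illustrates the per-file hash plus composite digest idea exercised by the tests above.
import { createHash } from 'node:crypto';
import { readFileSync, readdirSync } from 'node:fs';
import { join } from 'node:path';

// Assumed, illustrative extension list - the real one lives in the backend.
const MODEL_FILE_EXTENSIONS = ['.ckpt', '.safetensors', '.bin', '.pt', '.pth'];

function hashFile(path: string, algorithm = 'sha256'): string {
  return createHash(algorithm).update(readFileSync(path)).digest('hex');
}

function hashDir(dir: string, algorithm = 'sha256'): string {
  const files = readdirSync(dir)
    .filter((name) => MODEL_FILE_EXTENSIONS.some((ext) => name.endsWith(ext)))
    .sort(); // deterministic ordering is a design choice for this sketch
  const composite = createHash('sha256'); // stands in for blake3 in the real implementation
  for (const name of files) {
    composite.update(hashFile(join(dir, name), algorithm)); // feed each file's hex digest into the outer hash
  }
  return composite.digest('hex');
}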