Compare commits


119 Commits

Author SHA1 Message Date
Lincoln Stein
5206ddf9b2 truncate long prompts to avoid a crash with controlnet 2023-07-15 23:49:25 -04:00
Lincoln Stein
b767b5d44c user must adjust terminal size on Windows 2023-07-15 23:19:50 -04:00
Lincoln Stein
72c891bbac remove conhost from windows install process 2023-07-15 21:48:04 -04:00
Lincoln Stein
39e66ec934 rebuild front end 2023-07-15 20:32:22 -04:00
Lincoln Stein
eda1c94bd6 fix default launcher option 2023-07-15 20:22:12 -04:00
Lincoln Stein
e95cb3aa71 Merge branch 'lstein/default-model-install' into release/invokeai-3-0-beta 2023-07-15 20:16:51 -04:00
Lincoln Stein
373beefd13 remove restoration option from invokeai.yaml 2023-07-15 18:26:19 -04:00
Lincoln Stein
6b0a158ffa Merge branch 'main' into lstein/default-model-install 2023-07-15 18:23:34 -04:00
Lincoln Stein
c90345d6a3 deprecate the face restoration option 2023-07-15 18:23:32 -04:00
Lincoln Stein
bcb962f01f Setup textual inversion training with new model manager (#3775) 2023-07-15 18:23:11 -04:00
Lincoln Stein
70b12d9693 Merge branch 'main' into update-textual-inversion-training 2023-07-15 18:16:20 -04:00
Lincoln Stein
9faffa2245 revert inadvertent breaking change to config causing test failures (override) 2023-07-15 18:15:59 -04:00
Lincoln Stein
f66ead0819 Merge branch 'main' into update-textual-inversion-training 2023-07-15 17:44:45 -04:00
Lincoln Stein
0a121e139f add documentation on the configuration system (#3787)
This PR adds a new documentation file that describes how to configure
InvokeAI using `invokeai.yaml`, etc.
2023-07-15 17:39:40 -04:00
Lincoln Stein
6073cb8020 add documentation on the configuration system 2023-07-15 16:14:47 -04:00
psychedelicious
c7b547ea3e feat(nodes): remove references to restoration services
- remove restoration services
- remove the restore faces nodes
- update tests
2023-07-16 01:12:39 +10:00
psychedelicious
8a1b9d1001 chore(ui): regen types 2023-07-16 01:06:57 +10:00
psychedelicious
74ca87ac9e feat(nodes): add realesrgan node 2023-07-16 01:06:50 +10:00
Kent Keirsey
610c3a4512 merge main + migrate script now initializes destination root if needed (#3784) 2023-07-15 10:39:04 -04:00
Kent Keirsey
77b0129b4c Merge branch 'main' into lstein/migrate-fix 2023-07-15 10:37:56 -04:00
Lincoln Stein
e01706f5f5 add fp16 support to controlnet models 2023-07-15 10:37:11 -04:00
Lincoln Stein
f504c7ebbd Merge branch 'main' into lstein/migrate-fix 2023-07-15 10:13:44 -04:00
Lincoln Stein
a111539059 migrate script now initializes destination root if needed 2023-07-15 09:59:34 -04:00
Kent Keirsey
52948a1bbc Update CODEOWNERS 2023-07-15 08:52:41 -04:00
Lincoln Stein
32e7e52d69 Merge branch 'main' into lstein/default-model-install 2023-07-15 08:30:22 -04:00
blessedcoolant
8fc765694e fix: Minor UI tweak to Control Net enable button (#3783) 2023-07-16 00:28:15 +12:00
blessedcoolant
ff74de7a60 fix: Minor UI tweak to Control Net enable button 2023-07-16 00:27:52 +12:00
Lincoln Stein
07a2da40b8 Rewrite controlnet to new model manager (#3665) 2023-07-15 08:24:06 -04:00
psychedelicious
d234bf1cb9 feat(install): display full ESRGAN model filenames during installation 2023-07-15 21:27:10 +10:00
psychedelicious
f7230d07db feat(ui): fix controlnet image preview alignment 2023-07-15 20:49:03 +10:00
psychedelicious
b265956083 fix(ui): disable drop when controlnet disabled 2023-07-15 20:47:02 +10:00
psychedelicious
8e0ba24bf2 feat(ui): fix cnet ui alignment 2023-07-15 20:36:32 +10:00
psychedelicious
be4705ec32 feat(ui): move control mode and processor to main view 2023-07-15 20:34:26 +10:00
psychedelicious
4ac0ce59fb fix(ui): add custom label to IAIMantineSelects
needed to have their label styles match chakras
2023-07-15 20:29:15 +10:00
psychedelicious
7daafc03d3 fix(ui): fix invoke button styles when processing 2023-07-15 20:04:33 +10:00
psychedelicious
457e4b7fc5 feat(ui): tweak slider label spacing 2023-07-15 19:56:45 +10:00
psychedelicious
d1ecd007ab feat(ui): promote controlnet to be just under general
It is the most impactful feature, and also takes up the most space when you expand it. Promoted.
2023-07-15 19:56:45 +10:00
psychedelicious
7dec2d09f0 feat(ui): disable specific controlnet inputs when that controlnet is disabled
The UX is clearer now, but it's still easy to miss that your individual controlnets are enabled, but the overall controlnet feature is disabled.
2023-07-15 19:56:45 +10:00
psychedelicious
13d182ead2 feat(ui): move cnet add button to top of list 2023-07-15 19:56:45 +10:00
psychedelicious
401727b0c9 feat(ui): add cnet advanced tooltip 2023-07-15 19:56:45 +10:00
psychedelicious
19e076cd15 fix(ui): fix no controlnet model selected by default 2023-07-15 19:56:45 +10:00
psychedelicious
8a14c5db00 feat(ui): wip controlnet layout 2023-07-15 19:56:45 +10:00
psychedelicious
77ad3c959b feat(ui): tweak slider styles 2023-07-15 19:56:45 +10:00
psychedelicious
952a7a8674 feat(ui): do not autoprocess if user just disabled autoconfig 2023-07-15 19:56:45 +10:00
psychedelicious
7b6d91c69f feat(ui): control net UI weights 0 to 2 2023-07-15 19:56:44 +10:00
psychedelicious
8f66d826a5 feat(ui): refactor controlnet UI components to use local memoized selectors
makes them more portable and easier to reason about
2023-07-15 19:56:44 +10:00
psychedelicious
d270f21c85 feat(nodes): valid controlnet weights are -1 to 2 2023-07-15 19:56:44 +10:00
psychedelicious
ae72f372be fix(nodes): do not use hardcoded controlnet model 2023-07-15 19:56:44 +10:00
psychedelicious
0d41346417 feat(ui): fix controlNet models
- update controlnet state to use object format for model
- update model-parsing helper functions to log errors
- update nodes components, types and state
- remove controlnets from state when models are loaded and the controlnet's model is not available
2023-07-15 19:56:44 +10:00
Mary Hipp
76dc47e88d remove frontend constants, use backend response for controlnet models. add disabled state if base model is not compatible. clear control net model if main base model changes. add logic to guess processor and move it up in UI 2023-07-15 19:56:44 +10:00
psychedelicious
5ac114576f feat(ui): add controlnet field to nodes 2023-07-15 19:56:44 +10:00
psychedelicious
29b2e59e65 fix(nodes): fix ref to ctx mgr service, missing import 2023-07-15 19:56:44 +10:00
psychedelicious
96c9db6d2e chore(ui): typegen 2023-07-15 19:56:44 +10:00
psychedelicious
82fa39b531 feat(nodes): add controlnet nodes type hint 2023-07-15 19:56:44 +10:00
psychedelicious
788dcbde70 fix(nodes): add missing import 2023-07-15 19:56:44 +10:00
Sergey Borisov
6ab9a5e108 Draft 2023-07-15 19:56:44 +10:00
blessedcoolant
5d5a497ed4 Model manager route enhancements (#3768)
# Multiple enhancements to model manager REACT API
    
1. add a `/sync` route for synchronizing the in-memory model lists to
models.yaml, the models directory, and the autoimport directories.
2. added optional destination directories to convert_model and
merge_model operations.
3. added a `/ckpt_confs` route for retrieving known legacy checkpoint
configuration files.
4. added a `/search` route for finding all models in a directory located
in the server filesystem
5. added a `/add`  route for manual addition of a local models
6. added a `/rename` route for renaming and/or rebasing models
7. changed the path of the `import_model` route to `/import`

# Slightly annoying detail:

When adding a model manually using `/add`, the body JSON must exactly
match one of the model configurations returned by `list_models` (i.e.
there is no defaulting of fields). This includes the `error` field,
which should be set to "null".
2023-07-15 17:03:40 +12:00
blessedcoolant
808b2de709 Merge branch 'main' into lstein/model-manager-route-enhancements 2023-07-15 16:56:54 +12:00
Lincoln Stein
2faa7cee37 add rename_model route 2023-07-14 23:03:18 -04:00
Brandon
467414f214 Merge branch 'main' into update-textual-inversion-training 2023-07-14 22:32:09 -04:00
Mary Hipp
194434dbfa restore scrollbar 2023-07-15 12:25:28 +10:00
Brandon Rising
f88a338be0 Setup textual inversion training with new model manager 2023-07-14 21:58:51 -04:00
psychedelicious
8cb19578c2 fix(ui): fix crash on LoRA remove / weight change 2023-07-15 11:09:18 +10:00
blessedcoolant
54b813c981 fix(ui): fix mouse interactions (#3764) 2023-07-15 12:45:59 +12:00
blessedcoolant
c4a6f25717 Merge branch 'main' into fix/nodes/fix-mouse-interactions 2023-07-15 12:44:49 +12:00
Lincoln Stein
b306247eb5 remove clipseg model install 2023-07-14 20:39:42 -04:00
blessedcoolant
5aade31c22 Pad conditionings using zeros and encoder_attention_mask (#3772) 2023-07-15 12:35:10 +12:00
Lincoln Stein
a45f7ce355 add --list-models command 2023-07-14 19:52:47 -04:00
Lincoln Stein
eb9d74653d set default models for realesrgan, controlnet and text inversion 2023-07-14 19:03:41 -04:00
Sergey Borisov
7093e5d033 Pad conditionings using zeros and encoder_attention_mask 2023-07-15 00:52:54 +03:00
Kent Keirsey
565299c7a1 Minor Updates to NODES.md 2023-07-14 16:10:49 -04:00
ymgenesis
44662c0c0e remove incorrect hyperlink 2023-07-14 16:10:49 -04:00
ymgenesis
ba783d9466 add node grouping concepts & typos
Added a section that focuses more on conceptualizing node workflows as groups of nodes working together as a whole. Also, typos.
2023-07-14 16:10:49 -04:00
ymgenesis
759ca317d0 add node types & functions
Add a list of node types and their functions. Functions are just copied text from node descriptions in webui.
2023-07-14 16:10:49 -04:00
ymgenesis
3454b7654c add i2i and controlnet examples
Added examples for img2img and ControlNet
2023-07-14 16:10:49 -04:00
ymgenesis
19cbda56b6 Create NODES.md
Introductory nodes documentation
2023-07-14 16:10:49 -04:00
Lincoln Stein
e71ce83e9c Merge branch 'main' into lstein/model-manager-route-enhancements 2023-07-14 13:52:55 -04:00
Lincoln Stein
8600aad12b multiple enhancements to model manager REACT API
1. add a /sync route for synchronizing the in-memory model lists to
   models.yaml, the models directory, and the autoimport directories.

2. add optional destination_directories to convert_model and merge_model
   operations.

3. add /ckpt_confs route for retrieving known legacy checkpoint configuration
   files.

4. add /search route for finding all models in a directory located in the server
   filesystem
2023-07-14 13:45:16 -04:00
Jonathan
9960d7ca2a Allow ImageResizeInvocation w/h to take inputs from other nodes (#3765) 2023-07-15 05:34:13 +12:00
blessedcoolant
48561908b1 Merge branch 'main' into fix/nodes/fix-mouse-interactions 2023-07-15 04:13:46 +12:00
blessedcoolant
3210f96967 Model Manager 3.0 UI (#3735)
DONE:

-  Restore Update Model functionality
- Restore Delete Model functionality
- Restore Model Convert functionality
- Restore Model Merge functionality
- Refine UX (fine tweaks when everything is done - TODO)

TODO

- Add Model (will be finished in a future PR once the backend work is
done)
2023-07-15 04:13:31 +12:00
Lincoln Stein
ad076b1174 add model directory search route 2023-07-14 11:14:33 -04:00
psychedelicious
f6752965b7 fix(ui): allow decimals in number inputs
still some jank but eh
2023-07-15 01:05:10 +10:00
psychedelicious
30e45eaf47 feat(ui): hold shift to make nodes draggable from anywhere 2023-07-15 00:45:26 +10:00
psychedelicious
0257b4a611 fix(ui): fix mouse interactions 2023-07-15 00:13:45 +10:00
blessedcoolant
3c7cf72423 fix: Clean up merge models submit handler 2023-07-15 01:29:51 +12:00
blessedcoolant
2a533b275f Merge branch 'main' into mm-ui 2023-07-15 01:24:40 +12:00
blessedcoolant
25d07891b5 Merge branch 'mm-ui' of https://github.com/blessedcoolant/InvokeAI into mm-ui 2023-07-15 01:24:20 +12:00
blessedcoolant
401fa6deb5 fix: Misc fixes 2023-07-15 01:23:08 +12:00
psychedelicious
f68ab55d6b fix(ui): fix missing mantineTheme, fixes fonts 2023-07-14 23:16:05 +10:00
psychedelicious
79d65125c2 feat(ui): extract mantine component styles to hook, add less opinionated mantine components
IAIMantineSelect and IAIMantineMultiSelect have a bit of extra logic that prevents simple select functionality from working as expected.

- extract the styles into hooks
- rename those two components to IAIMantineSearchableSelect and IAIMantineSearchableMultiSelect
- Create IAIMantineSelect (which is just a dropdown) and use it in model manager and a few other places

When we only have a few options to present and searching is not efficient, we should use this instead.
2023-07-14 23:00:38 +10:00
psychedelicious
d4dfd84525 feat(ui): mm colors 2023-07-14 20:12:02 +10:00
psychedelicious
eb2a7058bf feat(ui): tweak fontSize in modellist 2023-07-14 19:49:05 +10:00
psychedelicious
56d209842f feat(ui): only show modellistitem when none in array 2023-07-14 19:46:18 +10:00
psychedelicious
0b2f0c05b2 fix(ui): fix selecting model does not update form 2023-07-14 19:31:52 +10:00
psychedelicious
1e5ae9d986 feat(ui): refactor model manager ui
- simplify UI logic in `ModelManagerPanel` components
- fix up the types a bit to make it easier to select models
- remove `openModel` state, just make it a useState since it is very local to model manager
2023-07-14 19:22:37 +10:00
psychedelicious
f2af82bf73 feat(ui): add model convert for success/failure handling 2023-07-14 17:39:00 +10:00
psychedelicious
6d7fb49a7a fix(ui): fix model edit button disabled status 2023-07-14 17:36:10 +10:00
psychedelicious
48a8bd4985 feat(ui): add model update for success/failure handling 2023-07-14 17:35:45 +10:00
psychedelicious
d8437d3036 feat(ui): add simple selectIsBusy selector 2023-07-14 17:34:34 +10:00
psychedelicious
a0cb18a12c feat(ui): refetch models on socket connect 2023-07-14 17:34:13 +10:00
psychedelicious
b2005d821a fix(ui): fix types for models queries 2023-07-14 16:59:31 +10:00
psychedelicious
66b12ab0ea fix(ui): do not blacklist the rtk query events
doing so breaks the devtools
2023-07-14 16:59:13 +10:00
blessedcoolant
834774ce4c fix: Merge Conflicts 2023-07-14 18:16:34 +12:00
blessedcoolant
7cd60214cb Merge branch 'main' into mm-ui 2023-07-14 18:14:45 +12:00
psychedelicious
d858a0c5d8 fix(ui): fix rtk tags
I had mixed up `type` and `id` on a bunch of the tags. Fixing those
2023-07-14 15:32:09 +10:00
blessedcoolant
abe2a0f9b4 fix: merge conflicts (name renamed to model_name) for models 2023-07-14 15:53:28 +12:00
blessedcoolant
16e93c6455 Merge branch 'main' into mm-ui 2023-07-14 15:46:53 +12:00
blessedcoolant
1c2144794c Merge branch 'main' into mm-ui 2023-07-13 13:58:22 +12:00
blessedcoolant
71e34ac256 Merge branch 'main' into mm-ui 2023-07-13 12:48:43 +12:00
blessedcoolant
31bb4bfc61 style: Update Model Manager Styling to new format 2023-07-12 23:12:12 +12:00
blessedcoolant
3db1aa738c feat: Restore Model Merge functionality 2023-07-12 22:43:06 +12:00
blessedcoolant
683229e285 fix: Update model convert toast message 2023-07-12 20:44:57 +12:00
blessedcoolant
2cedf6aed5 feat: Restore Model Convert Functionality 2023-07-12 20:40:58 +12:00
blessedcoolant
6238a53fdd feat: Add basic form validation for path input 2023-07-12 20:11:05 +12:00
blessedcoolant
310e401b03 feat: Create basic IAIMantineTextInput component for form usage 2023-07-12 20:10:33 +12:00
blessedcoolant
3568e28b1c fix: Type resolutions & Bug Fixes
- Fix checkpoint filter not working
- Resolve all typescript and undefined issues in Model Manager List / Edit Forms and main panel
2023-07-12 19:05:16 +12:00
blessedcoolant
5a6ad99d4e feat: Restore Delete Model Functionality 2023-07-12 16:39:07 +12:00
blessedcoolant
afb46564e8 feat: Restore Update Model functionality 2023-07-12 16:13:49 +12:00
170 changed files with 4435 additions and 4038 deletions

.github/CODEOWNERS

@@ -6,7 +6,7 @@
/mkdocs.yml @lstein @blessedcoolant
# nodes
/invokeai/app/ @Kyle0654 @blessedcoolant
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising
# installation and configuration
/pyproject.toml @lstein @blessedcoolant
@@ -22,7 +22,7 @@
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
# generation, model management, postprocessing
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising
# front ends
/invokeai/frontend/CLI @lstein


@@ -0,0 +1,287 @@
---
title: Configuration
---
# :material-tune-variant: InvokeAI Configuration
## Intro
InvokeAI has numerous runtime settings which can be used to adjust
many aspects of its operations, including the location of files and
directories, memory usage, and performance. These settings can be
viewed and customized in several ways:
1. By editing settings in the `invokeai.yaml` file.
2. By setting environment variables.
3. On the command-line, when InvokeAI is launched.
In addition, the most commonly changed settings are accessible
graphically via the `invokeai-configure` script.
### How the Configuration System Works
When InvokeAI is launched, the very first thing it needs to do is to
find its "root" directory, which contains its configuration files,
installed models, its database of images, and the folder(s) of
generated images themselves. In this document, the root directory will
be referred to as ROOT.
#### Finding the Root Directory
To find its root directory, InvokeAI uses the following recipe (sketched in code after the list):
1. It first looks for the argument `--root <path>` on the command line
it was launched from, and uses the indicated path if present.
2. Next it looks for the environment variable INVOKEAI_ROOT, and uses
the directory path found there if present.
3. If neither of these are present, then InvokeAI looks for the
folder containing the `.venv` Python virtual environment directory for
the currently active environment. This directory is checked for files
expected inside the InvokeAI root before it is used.
4. Finally, InvokeAI looks for a directory in the current user's home
directory named `invokeai`.
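The same search order can be expressed as an illustrative Python sketch. This is not InvokeAI's actual implementation; the `find_root` helper and the `invokeai.yaml` existence check in step 3 are assumptions made here for clarity:
```
import os
import sys
from pathlib import Path

def find_root() -> Path:
    # 1. An explicit --root <path> argument wins.
    if "--root" in sys.argv:
        return Path(sys.argv[sys.argv.index("--root") + 1])
    # 2. Otherwise use the INVOKEAI_ROOT environment variable, if set.
    if "INVOKEAI_ROOT" in os.environ:
        return Path(os.environ["INVOKEAI_ROOT"])
    # 3. Otherwise try the folder containing the active .venv, but only
    #    if it contains files expected inside an InvokeAI root.
    venv = os.environ.get("VIRTUAL_ENV")
    if venv and (Path(venv).parent / "invokeai.yaml").exists():
        return Path(venv).parent
    # 4. Finally, fall back to ~/invokeai.
    return Path.home() / "invokeai"
```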
#### Reading the InvokeAI Configuration File
Once the root directory has been located, InvokeAI looks for a file
named `ROOT/invokeai.yaml`, and if present reads configuration values
from it. The top of this file looks like this:
```
InvokeAI:
  Web Server:
    host: localhost
    port: 9090
    allow_origins: []
    allow_credentials: true
    allow_methods:
    - '*'
    allow_headers:
    - '*'
  Features:
    esrgan: true
    internet_available: true
    log_tokenization: false
    nsfw_checker: false
    patchmatch: true
    restore: true
...
```
The lines in this file are used to establish default values for
Invoke's settings. In the above fragment, the Web Server's listening
port is set to 9090 by the `port` setting.
You can edit this file with a text editor such as "Notepad" (do not
use Word or any other word processor). When editing, be careful to
maintain the indentation, and do not add extraneous text, as syntax
errors will prevent InvokeAI from launching. A basic guide to the
format of YAML files can be found
[here](https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/).
You can fix a broken `invokeai.yaml` by deleting it and running the
configuration script again -- option [7] in the launcher, "Re-run the
configure script".
#### Reading Environment Variables
Next InvokeAI looks for defined environment variables in the format
`INVOKEAI_<setting_name>`, for example `INVOKEAI_port`. Environment
variable values take precedence over configuration file variables. On
a Macintosh system, for example, you could change the port that the
web server listens on by setting the environment variable this way:
```
export INVOKEAI_port=8000
invokeai-web
```
Please check out these
[Macintosh](https://phoenixnap.com/kb/set-environment-variable-mac)
and
[Windows](https://phoenixnap.com/kb/windows-set-environment-variable)
guides for setting temporary and permanent environment variables.
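On Windows, assuming you are working in a Command Prompt window, the equivalent (for the current session only) would be:
```
set INVOKEAI_port=8000
invokeai-web
```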
#### Reading the Command Line
Lastly, InvokeAI takes settings from the command line, which override
everything else. The command-line settings have the same name as the
corresponding configuration file settings, preceded by a `--`, for
example `--port 8000`.
If you are using the launcher (`invoke.sh` or `invoke.bat`) to launch
InvokeAI, then just pass the command-line arguments to the launcher:
```
invoke.bat --port 8000 --host 0.0.0.0
```
The arguments will be applied when you select the web server option
(and the other options as well).
If, on the other hand, you prefer to launch InvokeAI directly from the
command line, you would first activate the virtual environment (known
as the "developer's console" in the launcher), and run `invokeai-web`:
```
> C:\Users\Fred\invokeai\.venv\scripts\activate
(.venv) > invokeai-web --port 8000 --host 0.0.0.0
```
You can get a listing and brief instructions for each of the
command-line options by giving the `--help` argument:
```
(.venv) > invokeai-web --help
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials]
[--allow_methods [ALLOW_METHODS ...]] [--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan]
[--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
[--nsfw_checker | --no-nsfw_checker] [--patchmatch | --no-patchmatch] [--restore | --no-restore]
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_cache_size MAX_CACHE_SIZE]
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--precision {auto,float16,float32,autocast}]
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled]
[--tiled_decode | --no-tiled_decode] [--root ROOT] [--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR]
[--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH] [--models_dir MODELS_DIR]
[--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]]
[--log_format {plain,color,syslog,legacy}] [--log_level {debug,info,warning,error,critical}]
...
```
## The Configuration Settings
The configuration settings are divided into several distinct
groups in `invokeai.yaml`:
### Web Server
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
| `port` | `9090` | Network port number that the web server will listen on |
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
The documentation for InvokeAI's API can be accessed by browsing to <http://localhost:9090/docs>.
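Since the server is built on FastAPI, you can also fetch the machine-readable OpenAPI schema directly, assuming the default host and port:
```
curl http://localhost:9090/openapi.json
```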
### Features
These configuration settings allow you to enable and disable various InvokeAI features:
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `esrgan` | `true` | Activate the ESRGAN upscaling options|
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
| `nsfw_checker` | `true` | Activate the NSFW checker to blur out risque images |
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) |
### Memory/Performance
These options tune InvokeAI's memory and performance characteristics.
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `always_use_cpu` | `false` | Use the CPU to generate images, even if a GPU is available |
| `free_gpu_mem` | `false` | Aggressively free up GPU memory after each operation; this will allow you to run in low-VRAM environments with some performance penalties |
| `max_cache_size` | `6` | Amount of CPU RAM (in GB) to reserve for caching models in memory; more cache allows you to keep models in memory and switch among them quickly |
| `max_vram_cache_size` | `2.75` | Amount of GPU VRAM (in GB) to reserve for caching models in VRAM; more cache speeds up generation but reduces the size of the images that can be generated. This can be set to zero to maximize the amount of memory available for generation. |
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
| `xformers_enabled` | `true` | If the x-formers memory-efficient attention module is installed, activate it for better memory usage and generation speed|
| `tiled_decode` | `false` | If true, then during the VAE decoding phase the image will be decoded a section at a time, reducing memory consumption at the cost of a performance hit |
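As an illustration, a configuration tuned for a low-VRAM system might combine several of these settings in `invokeai.yaml`. The `Memory/Performance` group name below follows the grouping used in this section and is assumed to match your file:
```
InvokeAI:
  Memory/Performance:
    free_gpu_mem: true
    max_cache_size: 4
    precision: float16
    sequential_guidance: true
```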
### Paths
These options set the paths of various directories and files used by
InvokeAI. Relative paths are interpreted relative to INVOKEAI_ROOT, so
if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
`autoimport/main`, then the corresponding directory will be located at
`/home/fred/invokeai/autoimport/main`.
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
| `outdir` | `outputs` | Location of the directory in which the gallery of generated and uploaded images will be stored |
| `use_memory_db` | `false` | Keep database information in memory rather than on disk; this will not preserve image gallery information across restarts |
Note that the autoimport directories will be searched recursively,
allowing you to organize the models into folders and subfolders in any
way you wish. In addition, while we have split up autoimport
directories by the type of model they contain, this isn't
necessary. You can combine different model types in the same folder
and InvokeAI will figure out what they are. So you can easily use just
one autoimport directory by commenting out the unneeded paths:
```
Paths:
  autoimport_dir: autoimport
  # lora_dir: null
  # embedding_dir: null
  # controlnet_dir: null
```
### Logging
These settings control the information, warning, and debugging
messages printed to the console log while InvokeAI is running:
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `log_handlers` | `console` | This controls where log messages are sent, and can be a list of one or more destinations. Values include `console`, `file`, `syslog` and `http`. These are described in more detail below |
| `log_format` | `color` | This controls the formatting of the log messages. Values are `plain`, `color`, `legacy` and `syslog` |
| `log_level` | `debug` | This filters messages according to the level of severity and can be one of `debug`, `info`, `warning`, `error` and `critical`. For example, setting to `warning` will display all messages at the warning level or higher, but won't display "debug" or "info" messages |
Several different log handler destinations are available, and multiple destinations are supported by providing a list:
```
log_handlers:
- console
- syslog=localhost
- file=/var/log/invokeai.log
```
* `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
* `syslog` is only available on Linux and Macintosh systems. It uses
the operating system's "syslog" facility to write log file entries
locally or to a remote logging machine. `syslog` offers a variety
of configuration options:
```
syslog=/dev/log - log to the /dev/log device
syslog=localhost - log to the network logger running on the local machine
syslog=localhost:512 - same as above, but using a non-standard port
syslog=fredserver,facility=LOG_USER,socktype=SOCK_DGRAM - log to LAN-connected server "fredserver" using the facility LOG_USER and datagram packets
```
* `http` can be used to log to a remote web server. The server must be
properly configured to receive and act on log messages. The option
accepts the URL to the web server, and a `method` argument
indicating whether the message should be submitted using the GET or
POST method.
```
http=http://my.server/path/to/logger,method=POST
```
The `log_format` option provides several alternative formats:
* `color` - default format providing time, date and a message, using text colors to distinguish different log severities
* `plain` - same as above, but monochrome text only
* `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
* `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.
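Putting these together, a hypothetical `invokeai.yaml` fragment that logs warnings and above to both the console and a file might look like this (the `Logging` group name is assumed, following the pattern of the other groups):
```
InvokeAI:
  Logging:
    log_handlers:
    - console
    - file=/var/log/invokeai.log
    log_format: plain
    log_level: warning
```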

docs/features/NODES.md

@@ -0,0 +1,202 @@
# Nodes Editor (Experimental Beta)
The nodes editor is a blank canvas allowing for the use of individual functions and image transformations to control the image generation workflow. The node processing flow usually runs from left (inputs) to right (outputs), though linearity can become abstracted the more complex the node graph becomes. Node inputs and outputs are connected by dragging connectors from node to node.
To better understand how nodes are used, think of how an electric power bar works. It takes in one input (electricity from a wall outlet) and passes it to multiple devices through multiple outputs. Similarly, a node could have multiple inputs and outputs functioning at the same (or different) time, but all node outputs pass information onward like a power bar passes electricity. Not all outputs are compatible with all inputs, however: each node has its own constraints on the information it expects to input and output. In general, node outputs are colour-coded to match compatible inputs of other nodes.
## Anatomy of a Node
Individual nodes are made up of the following:
- Inputs: Edge points on the left side of the node window where you connect outputs from other nodes.
- Outputs: Edge points on the right side of the node window where you connect to inputs on other nodes.
- Options: Various options which are either manually configured, or overridden by connecting an output from another node to the input.
## Diffusion Overview
Taking the time to understand the diffusion process will help you to understand how to set up your nodes in the nodes editor.
There are two main spaces Stable Diffusion works in: image space and latent space.
Image space represents images in pixel form that you look at. Latent space represents compressed inputs. It's in latent space that Stable Diffusion processes images. A VAE (Variational Auto Encoder) is responsible for compressing and encoding inputs into latent space, as well as decoding outputs back into image space.
When you generate an image using text-to-image, multiple steps occur in latent space:
1. Random noise is generated at the chosen height and width. The noise's characteristics are dictated by the chosen (or not chosen) seed. This noise tensor is passed into latent space. We'll call this noise A.
1. Using a model's U-Net, a noise predictor examines noise A and the words tokenized by CLIP from your prompt (conditioning). It generates its own noise tensor to predict what the final image might look like in latent space. We'll call this noise B.
1. Noise B is subtracted from noise A in an attempt to create a final latent image indicative of the inputs. This step is repeated for the number of sampler steps chosen.
1. The VAE decodes the final latent image from latent space into image space.
image-to-image is a similar process, with only step 1 being different:
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how much noise is added, 0 being none, and 1 being all-encompassing. We'll call this noise A. The process is then the same as steps 2-4 in the text-to-image explanation above.
Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor).
A noise scheduler (e.g. DPM++ 2M Karras) schedules the subtraction of noise from the latent image across the sampler steps chosen (step 3 above). Less noise is usually subtracted at higher sampler steps.
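As a toy illustration of steps 1-3 above (with made-up numbers standing in for real latent tensors and a real U-Net noise predictor; nothing here touches Stable Diffusion itself):
```
import random

def toy_denoise(width=4, height=4, steps=5, seed=123):
    rng = random.Random(seed)  # step 1: the seed dictates noise A
    noise_a = [rng.gauss(0, 1) for _ in range(width * height)]
    for _ in range(steps):  # repeated for the number of sampler steps
        noise_b = [0.2 * x for x in noise_a]  # step 2: a stand-in "prediction"
        noise_a = [a - b for a, b in zip(noise_a, noise_b)]  # step 3: subtract
    return noise_a  # step 4 would be the VAE decoding this into image space

print(toy_denoise())
```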
## Node Types (Base Nodes)
| Node <img width=160 align="right"> | Function |
| ---------------------------------- | --------------------------------------------------------------------------------------|
| Add | Adds two numbers |
| CannyImageProcessor | Canny edge detection for ControlNet |
| ClipSkip | Skip layers in clip text_encoder model |
| Collect | Collects values into a collection |
| Prompt (Compel) | Parse prompt using compel package to conditioning |
| ContentShuffleImageProcessor | Applies content shuffle processing to image |
| ControlNet | Collects ControlNet info to pass to other nodes |
| CvInpaint | Simple inpaint using opencv |
| Divide | Divides two numbers |
| DynamicPrompt | Parses a prompt using adieyal/dynamic prompt's random or combinatorial generator |
| FloatLinearRange | Creates a range |
| HedImageProcessor | Applies HED edge detection to image |
| ImageBlur | Blurs an image |
| ImageChannel | Gets a channel from an image |
| ImageCollection | Load a collection of images and provide it as output |
| ImageConvert | Converts an image to a different mode |
| ImageCrop | Crops an image to a specified box. The box can be outside of the image. |
| ImageInverseLerp | Inverse linear interpolation of all pixels of an image |
| ImageLerp | Linear interpolation of all pixels of an image |
| ImageMultiply | Multiplies two images together using `PIL.ImageChops.Multiply()` |
| ImagePaste | Pastes an image into another image |
| ImageProcessor | Base class for invocations that preprocess images for ControlNet |
| ImageResize | Resizes an image to specific dimensions |
| ImageScale | Scales an image by a factor |
| ImageToLatents | Encodes an image into latents |
| InfillColor | Infills transparent areas of an image with a solid color |
| InfillPatchMatch | Infills transparent areas of an image using the PatchMatch algorithm |
| InfillTile | Infills transparent areas of an image with tiles of the image |
| Inpaint | Generates an image using inpaint |
| Iterate | Iterates over a list of items |
| LatentsToImage | Generates an image from latents |
| LatentsToLatents | Generates latents using latents as base image |
| LeresImageProcessor | Applies leres processing to image |
| LineartAnimeImageProcessor | Applies line art anime processing to image |
| LineartImageProcessor | Applies line art processing to image |
| LoadImage | Load an image and provide it as output |
| Lora Loader | Apply selected lora to unet and text_encoder |
| Model Loader | Loads a main model, outputting its submodels |
| MaskFromAlpha | Extracts the alpha channel of an image as a mask |
| MediapipeFaceProcessor | Applies mediapipe face processing to image |
| MidasDepthImageProcessor | Applies Midas depth processing to image |
| MlsdImageProcessor | Applies MLSD processing to image |
| Multiply | Multiplies two numbers |
| Noise | Generates latent noise |
| NormalbaeImageProcessor | Applies NormalBAE processing to image |
| OpenposeImageProcessor | Applies Openpose processing to image |
| ParamFloat | A float parameter |
| ParamInt | An integer parameter |
| PidiImageProcessor | Applies PIDI processing to an image |
| Progress Image | Displays the progress image in the Node Editor |
| RandomInt | Outputs a single random integer |
| RandomRange | Creates a collection of random numbers |
| Range | Creates a range of numbers from start to stop with step |
| RangeOfSize | Creates a range from start to start + size with step |
| ResizeLatents | Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8. |
| RestoreFace | Restores faces in the image |
| ScaleLatents | Scales latents by a given factor |
| SegmentAnythingProcessor | Applies segment anything processing to image |
| ShowImage | Displays a provided image, and passes it forward in the pipeline |
| StepParamEasing | Experimental per-step parameter for easing for denoising steps |
| Subtract | Subtracts two numbers |
| TextToLatents | Generates latents from conditionings |
| TileResampleProcessor | Base class for invocations that preprocess images for ControlNet |
| Upscale | Upscales an image |
| VAE Loader | Loads a VAE model, outputting a VaeLoaderOutput |
| ZoeDepthImageProcessor | Applies Zoe depth processing to image |
## Node Grouping Concepts
There are several node grouping concepts that can be examined with a narrow focus. These (and other) groupings can be pieced together to make up functional graph setups, and are important to understanding how groups of nodes work together as part of a whole. Note that the screenshots below aren't examples of complete functioning node graphs (see Examples).
### Noise
As described, an initial noise tensor is necessary for the latent diffusion process. As a result, all non-image *ToLatents nodes require a noise node input.
<img width="654" alt="groupsnoise" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/2e8d297e-ad55-4d27-bc93-c119dad2a2c5">
### Conditioning
As described, conditioning is necessary for the latent diffusion process, whether empty or not. As a result, all non-image *ToLatents nodes require positive and negative conditioning inputs. Conditioning is reliant on a CLIP tokenizer provided by the Model Loader node.
<img width="1024" alt="groupsconditioning" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/f8f7ad8a-8d9c-418e-b5ad-1437b774b27e">
### Image Space & VAE
The ImageToLatents node doesn't require a noise node input, but requires a VAE input to convert the image from image space into latent space. In reverse, the LatentsToImage node requires a VAE input to convert from latent space back into image space.
<img width="637" alt="groupsimgvae" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/dd99969c-e0a8-4f78-9b17-3ffe179cef9a">
### Defined & Random Seeds
It is common to want to use both the same seed (for continuity) and random seeds (for variance). To define a seed, simply enter it into the 'Seed' field on a noise node. Conversely, the RandomInt node generates a random integer between 'Low' and 'High', and can be used as input to the 'Seed' edge point on a noise node to randomize your seed.
<img width="922" alt="groupsrandseed" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/af55bc20-60f6-438e-aba5-3ec871443710">
### Control
Control means to guide the diffusion process to adhere to a defined input or structure. Control can be provided as input to non-image *ToLatents nodes from ControlNet nodes. ControlNet nodes usually require an image processor which converts an input image for use with ControlNet.
<img width="805" alt="groupscontrol" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/cc9c5de7-23a7-46c8-bbad-1f3609d999a6">
### LoRA
The Lora Loader node lets you load a LoRA (say that ten times fast) and pass it as output to both the Prompt (Compel) and non-image *ToLatents nodes. A model's CLIP tokenizer is passed through the LoRA into Prompt (Compel), where it affects conditioning. A model's U-Net is also passed through the LoRA into a non-image *ToLatents node, where it affects noise prediction.
<img width="993" alt="groupslora" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/630962b0-d914-4505-b3ea-ccae9b0269da">
### Scaling
Use the ImageScale, ScaleLatents, and Upscale nodes to upscale images and/or latent images. The chosen method differs across contexts. However, be aware that latents are already noisy and compressed at their original resolution; scaling an image could produce more detailed results.
<img width="644" alt="groupsallscale" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/99314f05-dd9f-4b6d-b378-31de55346a13">
### Iteration + Multiple Images as Input
Iteration is a common concept in any processing, and means to repeat a process with given input. In nodes, you're able to use the Iterate node to iterate through collections usually gathered by the Collect node. The Iterate node has many potential uses, from processing a collection of images one after another, to varying seeds across multiple image generations and more. This screenshot demonstrates how to collect several images and pass them out one at a time.
<img width="788" alt="groupsiterate" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/4af5ca27-82c9-4018-8c5b-024d3ee0a121">
### Multiple Image Generation + Random Seeds
Multiple image generation in the node editor is done using the RandomRange node. In this case, the 'Size' field represents the number of images to generate. As RandomRange produces a collection of integers, we need to add the Iterate node to iterate through the collection.
To control seeds across generations takes some care. The first row in the screenshot will generate multiple images with different seeds, but using the same RandomRange parameters across invocations will result in the same group of random seeds being used across the images, producing repeatable results. In the second row, adding the RandomInt node as input to RandomRange's 'Seed' edge point will ensure that seeds are varied across all images across invocations, producing varied results.
<img width="1027" alt="groupsmultigenseeding" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/518d1b2b-fed1-416b-a052-ab06552521b3">
## Examples
With our knowledge of node grouping and the diffusion process, let's break down some basic graphs in the nodes editor. Note that a node's options can be overridden by inputs from other nodes. These examples aren't strict rules to follow and only demonstrate some basic configurations.
### Basic text-to-image Node Graph
<img width="875" alt="nodest2i" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/17c67720-c376-4db8-94f0-5e00381a61ee">
- Model Loader: A necessity for generating images (as we've read above). We choose our model from the dropdown. It outputs a U-Net, CLIP tokenizer, and VAE.
- Prompt (Compel): Another necessity. Two prompt nodes are created. One will output positive conditioning (what you want, dog), one will output negative (what you don't want, cat). They both input the CLIP tokenizer that the Model Loader node outputs.
- Noise: Consider this noise A from step one of the text-to-image explanation above. Choose a seed number, width, and height.
- TextToLatents: This node takes many inputs for converting and processing text & noise from image space into latent space, hence the name TextTo**Latents**. In this setup, it inputs positive and negative conditioning from the prompt nodes for processing (step 2 above). It inputs noise from the noise node for processing (steps 2 & 3 above). Lastly, it inputs a U-Net from the Model Loader node for processing (step 2 above). It outputs latents for use in the next LatentsToImage node. Choose number of sampler steps, CFG scale, and scheduler.
- LatentsToImage: This node takes in processed latents from the TextToLatents node, and the model's VAE from the Model Loader node which is responsible for decoding latents back into the image space, hence the name LatentsTo**Image**. This node is the last stop, and once the image is decoded, it is saved to the gallery.
### Basic image-to-image Node Graph
<img width="998" alt="nodesi2i" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/3f2c95d5-cee7-4415-9b79-b46ee60a92fe">
- Model Loader: Choose a model from the dropdown.
- Prompt (Compel): Two prompt nodes. One positive (dog), one negative (dog). Same CLIP inputs from the Model Loader node as before.
- ImageToLatents: Upload a source image directly in the node window, via drag'n'drop from the gallery, or passed in as input. The ImageToLatents node inputs the VAE from the Model Loader node to encode the chosen image from image space into latent space, hence the name ImageTo**Latents**. It outputs latents for use in the next LatentsToLatents node. It also outputs the source image's width and height for use in the next Noise node if the final image is to be the same dimensions as the source image.
- Noise: A noise tensor is created with the width and height of the source image, and connected to the next LatentsToLatents node. Notice the width and height fields are overridden by the input from the ImageToLatents width and height outputs.
- LatentsToLatents: The inputs and options are nearly identical to TextToLatents, except that LatentsToLatents also takes latents as an input. Considering our source image is already converted to latents in the last ImageToLatents node, and text + noise are no longer the only inputs to process, we use the LatentsToLatents node.
- LatentsToImage: Like previously, the LatentsToImage node will use the VAE from the Model Loader as input to decode the latents from LatentsToLatents into image space, and save it to the gallery.
### Basic ControlNet Node Graph
<img width="703" alt="nodescontrol" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/b02ded86-ceb4-44a2-9910-e19ad184d471">
- Model Loader
- Prompt (Compel)
- Noise: Width and height of the CannyImageProcessor ControlNet image are passed in to set the dimensions of the noise passed to TextToLatents.
- CannyImageProcessor: The CannyImageProcessor node is used to process the source image being used as a ControlNet. Each ControlNet processor node applies control in different ways, and has some different options to configure. Width and height are passed to noise, as mentioned. The processed ControlNet image is output to the ControlNet node.
- ControlNet: Select the type of control model. In this case, canny is chosen as the CannyImageProcessor was used to generate the ControlNet image. Configure the control node options, and pass the control output to TextToLatents.
- TextToLatents: Similar to the basic text-to-image example, except ControlNet is passed to the control input edge point.
- LatentsToImage


@@ -153,6 +153,9 @@ This method is recommended for those familiar with running Docker containers
- [Prompt Syntax](features/PROMPTS.md)
- [Generating Variations](features/VARIATIONS.md)
### InvokeAI Configuration
- [Guide to InvokeAI Runtime Settings](features/CONFIGURATION.md)
## :octicons-log-16: Important Changes Since Version 2.3
### Nodes


@@ -24,7 +24,8 @@ read -e -p "Tag this repo with '${VERSION}' and '${LATEST_TAG}'? [n]: " input
RESPONSE=${input:='n'}
if [ "$RESPONSE" == 'y' ]; then
if ! git tag $VERSION ; then
git push origin :refs/tags/$VERSION
if ! git tag -fa $VERSION ; then
echo "Existing/invalid tag"
exit -1
fi


@@ -38,7 +38,7 @@ echo https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist
echo.
echo See %INSTRUCTIONS% for more details.
echo.
echo "For the best user experience we suggest enlarging or maximizing this window now."
echo FOR THE BEST USER EXPERIENCE WE SUGGEST MAXIMIZING THIS WINDOW NOW.
pause
@rem ---------------------------- check Python version ---------------


@@ -19,7 +19,7 @@ echo 8. Open the developer console
echo 9. Update InvokeAI
echo 10. Command-line help
echo Q - Quit
set /P choice="Please enter 1-10, Q: [2] "
set /P choice="Please enter 1-10, Q: [1] "
if not defined choice set choice=1
IF /I "%choice%" == "1" (
echo Starting the InvokeAI browser-based UI..


@@ -20,7 +20,6 @@ from invokeai.version.invokeai_version import __version__
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue
@@ -58,7 +57,7 @@ class ApiDependencies:
@staticmethod
def initialize(config, event_handler_id: int, logger: Logger = logger):
logger.debug(f'InvokeAI version {__version__}')
logger.debug(f"InvokeAI version {__version__}")
logger.debug(f"Internet connectivity is {config.internet_available}")
events = FastAPIEventService(event_handler_id)
@@ -117,7 +116,7 @@ class ApiDependencies:
)
services = InvocationServices(
model_manager=ModelManagerService(config,logger),
model_manager=ModelManagerService(config, logger),
events=events,
latents=latents,
images=images,
@@ -129,7 +128,6 @@ class ApiDependencies:
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config, logger),
configuration=config,
logger=logger,
)


@@ -1,6 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
import pathlib
from typing import Literal, List, Optional, Union
from fastapi import Body, Path, Query, Response
@@ -22,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@@ -78,7 +80,7 @@ async def update_model(
return model_response
@models_router.post(
"/",
"/import",
operation_id="import_model",
responses= {
201: {"description" : "The model imported successfully"},
@@ -94,7 +96,7 @@ async def import_model(
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL """
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType }
@@ -126,18 +128,100 @@ async def import_model(
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/add",
operation_id="add_model",
responses= {
201: {"description" : "The model added successfully"},
404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse
)
async def add_model(
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.add_model(
info.model_name,
info.base_model,
info.model_type,
model_attributes = info.dict()
)
logger.info(f'Successfully added {info.model_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.model_name,
base_model=info.base_model,
model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/rename/{base_model}/{model_type}/{model_name}",
operation_id="rename_model",
responses= {
201: {"description" : "The model was renamed successfully"},
404: {"description" : "The model could not be found"},
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code=201,
response_model=ImportModelResponse
)
async def rename_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="current model name"),
new_name: Optional[str] = Query(description="new model name", default=None),
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
) -> ImportModelResponse:
""" Rename a model"""
logger = ApiDependencies.invoker.services.logger
try:
result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)
logger.debug(result)
logger.info(f'Successfully renamed {model_name}=>{new_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=new_name or model_name,
base_model=new_base or base_model,
model_type=model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete(
"/{base_model}/{model_type}/{model_name}",
operation_id="del_model",
responses={
204: {
"description": "Model deleted successfully"
},
404: {
"description": "Model not found"
}
204: { "description": "Model deleted successfully" },
404: { "description": "Model not found" }
},
status_code = 204,
response_model = None,
)
async def delete_model(
base_model: BaseModelType = Path(description="Base model"),
@@ -173,14 +257,17 @@ async def convert_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model"""
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Converting model: {model_name}")
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
base_model = base_model,
model_type = model_type
model_type = model_type,
convert_dest_directory = dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
base_model = base_model,
@@ -191,6 +278,53 @@ async def convert_model(
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
@models_router.get(
"/search",
operation_id="search_for_models",
responses={
200: { "description": "Directory searched successfully" },
404: { "description": "Invalid directory path" },
},
status_code = 200,
response_model = List[pathlib.Path]
)
async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models")
)->List[pathlib.Path]:
if not search_path.is_dir():
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
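A minimal sketch of calling this endpoint with `requests`; the host, port, and mount point are assumptions:

import requests

resp = requests.get(
    "http://localhost:9090/api/v1/models/search",     # hypothetical base URL
    params={"search_path": "/data/checkpoints"},      # any directory on the server
)
resp.raise_for_status()              # a 404 here means the directory does not exist
for model_path in resp.json():       # the body is a JSON list of discovered paths
    print(model_path)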
@models_router.get(
"/ckpt_confs",
operation_id="list_ckpt_configs",
responses={
200: { "description" : "paths retrieved successfully" },
},
status_code = 200,
response_model = List[pathlib.Path]
)
async def list_ckpt_configs(
)->List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
@models_router.get(
"/sync",
operation_id="sync_to_config",
responses={
201: { "description": "synchronization successful" },
},
status_code = 201,
response_model = None
)
async def sync_to_config(
)->None:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures."""
return ApiDependencies.invoker.services.model_manager.sync_to_config()
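Both GET endpoints above are simple from a client's perspective; a short sketch, with the base URL assumed as before:

import requests

BASE = "http://localhost:9090/api/v1/models"   # hypothetical base URL

# list the legacy checkpoint configs, as ROOT-relative paths
print(requests.get(f"{BASE}/ckpt_confs").json())

# after hand-editing models.yaml or adding files to an autoimport directory
requests.get(f"{BASE}/sync")                   # 201 on success, empty body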
@models_router.put(
"/merge/{base_model}",
@@ -210,17 +344,21 @@ async def merge_models(
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {model_names}")
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
base_model,
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory = dest
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
base_model = base_model,
model_type = ModelType.Main,

View File

@@ -54,7 +54,6 @@ from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
import torch
@@ -295,7 +294,6 @@ def invoke_cli():
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
logger=logger,
configuration=config,
)

View File

@@ -100,7 +100,7 @@ class CompelInvocation(BaseInvocation):
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True,
)
conjunction = Compel.parse_prompt_string(self.prompt)
@@ -112,9 +112,6 @@ class CompelInvocation(BaseInvocation):
c, options = compel.build_conditioning_tensor_for_prompt_object(
prompt)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(
tokenizer, conjunction),

View File

@@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict
from PIL import Image
from pydantic import BaseModel, Field, validator
from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
@@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
model_name: str = Field(description="Name of the ControlNet model")
base_model: BaseModelType = Field(description="Base model")
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1,
@@ -118,15 +125,15 @@ class ControlField(BaseModel):
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@validator("control_weight")
def validate_control_weight(cls, v):
"""Validate that all control weights in the valid range"""
if isinstance(v, list):
for i in v:
if i < -1 or i > 2:
raise ValueError('Control weights must be within -1 to 2 range')
else:
if v < -1 or v > 2:
raise ValueError('Control weights must be within -1 to 2 range')
return v
class Config:
schema_extra = {
@@ -134,6 +141,7 @@ class ControlField(BaseModel):
"ui": {
"type_hints": {
"control_weight": "float",
"control_model": "controlnet_model",
# "control_weight": "number",
}
}
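A quick pydantic-level illustration of what the widened validator accepts and rejects; the values are illustrative, and ControlField is assumed importable from this module:

from pydantic import ValidationError

ok = ControlField(control_weight=[0.0, 1.5, -1.0])   # anywhere in [-1, 2], scalar or list

try:
    ControlField(control_weight=2.5)                 # outside the allowed range
except ValidationError as err:
    print(err)    # "Control weights must be within -1 to 2 range"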
@@ -154,10 +162,10 @@ class ControlNetInvocation(BaseInvocation):
type: Literal["controlnet"] = "controlnet"
# Inputs
image: ImageField = Field(default=None, description="The control image")
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")

View File

@@ -5,6 +5,7 @@ from typing import Literal, Optional
import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field
from typing import Union
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
@@ -398,8 +399,8 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to resize")
width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on

View File

@@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import einops
@@ -11,6 +12,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models.base import ModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
@@ -71,16 +73,21 @@ def get_scheduler(
scheduler_name: str,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
scheduler_name, SCHEDULER_MAP['ddim']
)
orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict()
)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {
**scheduler_config,
**scheduler_extra_config,
"_backup": scheduler_config,
}
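# Note: stashing the pristine config under "_backup" lets the next pass through
# this code path recover the original values (see the "_backup" check above)
# after a different scheduler's extra config has been merged in.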
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
@@ -137,8 +144,11 @@ class TextToLatentsInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
@@ -147,11 +157,16 @@ class TextToLatentsInvocation(BaseInvocation):
)
def get_conditioning_data(
self,
context: InvocationContext,
scheduler,
) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name
)
uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name
)
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
@@ -178,7 +193,10 @@ class TextToLatentsInvocation(BaseInvocation):
return conditioning_data
def create_pipeline(
self,
unet,
scheduler,
) -> StableDiffusionGeneratorPipeline:
# TODO:
# configure_model_padding(
# unet,
@@ -213,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation):
model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField],
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
@@ -238,25 +257,19 @@ class TextToLatentsInvocation(BaseInvocation):
control_data = []
control_models = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
)
)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(
control_image_field.image_name
)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
@@ -278,7 +291,8 @@ class TextToLatentsInvocation(BaseInvocation):
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
@@ -289,7 +303,8 @@ class TextToLatentsInvocation(BaseInvocation):
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
@@ -298,14 +313,17 @@ class TextToLatentsInvocation(BaseInvocation):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
**lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
)
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler(
@@ -322,6 +340,7 @@ class TextToLatentsInvocation(BaseInvocation):
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# TODO: Verify the noise is the right size
@@ -374,7 +393,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
@@ -383,14 +403,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
**lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
)
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler(
@@ -407,11 +430,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype
)
timesteps, _ = pipeline.get_img2img_timesteps(
self.steps,
@@ -535,7 +560,8 @@ class ResizeLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate(
latents, size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,)
if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@@ -569,7 +595,8 @@ class ScaleLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate(
latents, scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,)
if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()

View File

@@ -1,55 +0,0 @@
from typing import Literal, Optional
from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
# fmt: off
type: Literal["restore_face"] = "restore_face"
# Inputs
image: Optional[ImageField] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["restoration", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=None,
strength=self.strength, # GFPGAN strength
save_original=False,
image_callback=None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -1,48 +1,112 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path
from typing import Literal, Union, cast
import cv2 as cv
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from pydantic import Field
from realesrgan import RealESRGANer
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageOutput
# TODO: Populate this from disk?
# TODO: Use model manager to load?
REALESRGAN_MODELS = Literal[
"RealESRGAN_x4plus.pth",
"RealESRGAN_x4plus_anime_6B.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
]
class RealESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN."""
# Inputs
type: Literal["realesrgan"] = "realesrgan"
image: Union[ImageField, None] = Field(default=None, description="The input image")
model_name: REALESRGAN_MODELS = Field(
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) # type: ignore
models_dir = cast(Path, context.services.configuration.root_dir) / Path("models/") # type: ignore
rrdbnet_model = None
netscale = None
model_path = None
if self.model_name in [
"RealESRGAN_x4plus.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
]:
# x4 RRDBNet model
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
netscale = 4
elif self.model_name == "RealESRGAN_x4plus_anime_6B.pth":
# x4 RRDBNet model, 6 blocks
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=6, # 6 blocks
num_grow_ch=32,
scale=4,
)
netscale = 4
# TODO: add x2 models handling?
# elif self.model_name in ["RealESRGAN_x2plus"]:
# # x2 RRDBNet model
# model = RRDBNet(
# num_in_ch=3,
# num_out_ch=3,
# num_feat=64,
# num_block=23,
# num_grow_ch=32,
# scale=2,
# )
# model_path = Path()
# netscale = 2
else:
msg = f"Invalid RealESRGAN model: {self.model_name}"
context.services.logger.error(msg)
raise ValueError(msg)
model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
upsampler = RealESRGANer(
scale=netscale,
model_path=str(models_dir / model_path),
model=rrdbnet_model,
half=False,
)
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
# We can pass an `outscale` value here, but it just resizes the image by that factor after
# upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
# upscaling, you'll need to add a resize node after this one.
upscaled_image, img_mode = upsampler.enhance(cv_image)
# back to PIL
pil_image = Image.fromarray(
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
).convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,

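Since the node's image handling hinges on the BGR/RGB flip, here is the round-trip in isolation: a minimal sketch using the same cv2 and PIL calls as the node above, with hypothetical helper names:

import cv2 as cv
import numpy as np
from PIL import Image

def pil_to_bgr(image: Image.Image) -> np.ndarray:
    # Real-ESRGAN works on cv2 (BGR) arrays
    return cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)

def bgr_to_pil(arr: np.ndarray) -> Image.Image:
    # flip channels back before handing the result to PIL
    return Image.fromarray(cv.cvtColor(arr, cv.COLOR_BGR2RGB)).convert("RGBA")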
View File

@@ -200,7 +200,7 @@ class InvokeAISettings(BaseSettings):
type = get_args(get_type_hints(cls)['type'])[0]
field_dict = dict({type:dict()})
for name,field in self.__fields__.items():
if name in cls._excluded_from_yaml():
continue
category = field.field_info.extra.get("category") or "Uncategorized"
value = getattr(self,name)
@@ -271,8 +271,13 @@ class InvokeAISettings(BaseSettings):
@classmethod
def _excluded(cls)->List[str]:
# internal fields that shouldn't be exposed as command line options
return ['type','initconf']
@classmethod
def _excluded_from_yaml(cls)->List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore']
class Config:
env_file_encoding = 'utf-8'
@@ -361,7 +366,7 @@ setting environment variables INVOKEAI_<setting>.
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')

View File

@@ -10,7 +10,6 @@ if TYPE_CHECKING:
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.config import InvokeAISettings
@@ -34,7 +33,6 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
queue: "InvocationQueueABC"
restoration: "RestorationServices"
def __init__(
self,
@@ -50,7 +48,6 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC",
queue: "InvocationQueueABC",
restoration: "RestorationServices",
):
self.board_images = board_images
self.boards = boards
@@ -65,4 +62,3 @@ class InvocationServices:
self.model_manager = model_manager
self.processor = processor
self.queue = queue

View File

@@ -19,7 +19,7 @@ from invokeai.backend.model_management import (
ModelMerger,
MergeInterpolationMethod,
)
from invokeai.backend.model_management.model_search import FindModels
import torch
from invokeai.app.models.exceptions import CanceledException
@@ -167,6 +167,27 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def rename_model(self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
"""
Rename the indicated model.
"""
pass
@abstractmethod
def list_checkpoint_configs(
self
)->List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
pass
@abstractmethod
def convert_model(
self,
@@ -220,6 +241,7 @@ class ModelManagerServiceBase(ABC):
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None
) -> AddModelResult:
"""
Merge two to three diffusers pipeline models and save as a new model.
@@ -228,9 +250,26 @@ class ModelManagerServiceBase(ABC):
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass
@abstractmethod
def search_for_models(self, directory: Path)->List[Path]:
"""
Return list of all models found in the designated directory.
"""
pass
@abstractmethod
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
@@ -431,16 +470,18 @@ class ModelManagerService(ModelManagerServiceBase):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well.
"""
self.logger.debug(f'delete model {model_name}')
self.mgr.del_model(model_name, base_model, model_type)
self.mgr.commit()
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
@@ -449,13 +490,14 @@ class ModelManagerService(ModelManagerServiceBase):
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
This will raise a ValueError if the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
self.logger.debug(f'convert model {model_name}')
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
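A hedged sketch of driving the conversion from the service layer; the model name and destination path are placeholders, and `mgr` stands for an already-initialized ModelManagerService:

from pathlib import Path

result = mgr.convert_model(
    model_name="my-checkpoint-model",                # hypothetical installed checkpoint
    base_model=BaseModelType.StableDiffusion1,
    model_type=ModelType.Main,
    convert_dest_directory=Path("/data/converted"),  # omit to place under models/
)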
def commit(self, conf_file: Optional[Path]=None):
"""
@@ -536,6 +578,7 @@ class ModelManagerService(ModelManagerServiceBase):
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
) -> AddModelResult:
"""
Merge two to three diffusers pipeline models and save as a new model.
@@ -544,6 +587,7 @@ class ModelManagerService(ModelManagerServiceBase):
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
merger = ModelMerger(self.mgr)
try:
@@ -554,7 +598,55 @@ class ModelManagerService(ModelManagerServiceBase):
alpha = alpha,
interp = interp,
force = force,
merge_dest_directory=merge_dest_directory,
)
except AssertionError as e:
raise ValueError(e)
return result
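And the corresponding sketch for a merge with the new destination parameter; model names and paths are placeholders, `mgr` as above:

from pathlib import Path

result = mgr.merge_models(
    model_names=["modelA", "modelB"],                # two or three installed mains
    base_model=BaseModelType.StableDiffusion1,
    merged_model_name="modelA+modelB",
    alpha=0.5,                                       # weighting for the 2nd/3rd model
    interp=None,                                     # None = weighted-sum interpolation
    force=False,
    merge_dest_directory=Path("/data/merges"),       # omit to save under models/
)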
def search_for_models(self, directory: Path)->List[Path]:
"""
Return list of all models found in the designated directory.
"""
search = FindModels(directory,self.logger)
return search.list_models()
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
return self.mgr.sync_to_config()
def list_checkpoint_configs(self)->List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
config = self.mgr.app_config
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
def rename_model(self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: Optional[str] = None,
new_base: Optional[BaseModelType] = None,
):
"""
Rename the indicated model. Can provide a new name and/or a new base.
:param model_name: Current name of the model
:param base_model: Current base of the model
:param model_type: Model type (can't be changed)
:param new_name: New name for the model
:param new_base: New base for the model
"""
self.mgr.rename_model(base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)
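Usage of the service-level rename, a sketch with placeholder names (`mgr` as above). Note that the underlying manager commits models.yaml itself after a rename, so a follow-up sync is only needed after external edits:

mgr.rename_model(
    model_name="my-model",
    base_model=BaseModelType.StableDiffusion1,
    model_type=ModelType.Main,
    new_name="my-model-v2",
    new_base=None,          # keep the current base
)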

View File

@@ -1,113 +0,0 @@
import sys
import traceback
import torch
from typing import types
from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
# This should be a real base class for postprocessing functions,
# but right now we just instantiate the existing gfpgan, esrgan
# and codeformer functions.
class RestorationServices:
'''Face restoration and upscaling'''
def __init__(self,args,logger:types.ModuleType):
try:
gfpgan, codeformer, esrgan = None, None, None
if args.restore or args.esrgan:
restoration = Restoration()
# TODO: redo for new model structure
if False and args.restore:
gfpgan, codeformer = restoration.load_face_restore_models(
args.gfpgan_model_path
)
else:
logger.info("Face restoration disabled")
if False and args.esrgan:
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
else:
logger.info("Upscaling disabled")
else:
logger.info("Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
self.device = torch.device(choose_torch_device())
self.gfpgan = gfpgan
self.codeformer = codeformer
self.esrgan = esrgan
self.logger = logger
self.logger.info('Face restoration initialized')
# note that this one method does gfpgan and codepath reconstruction, as well as
# esrgan upscaling
# TO DO: refactor into separate methods
def upscale_and_reconstruct(
self,
image_list,
facetool="gfpgan",
upscale=None,
upscale_denoise_str=0.75,
strength=0.0,
codeformer_fidelity=0.75,
save_original=False,
image_callback=None,
prefix=None,
):
results = []
for r in image_list:
image, seed = r
try:
if strength > 0:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == "gfpgan":
if self.gfpgan is None:
self.logger.info(
"GFPGAN not found. Face restoration is disabled."
)
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == "codeformer":
if self.codeformer is None:
self.logger.info(
"CodeFormer not found. Face restoration is disabled."
)
else:
cf_device = (
CPU_DEVICE if self.device == MPS_DEVICE else self.device
)
image = self.codeformer.process(
image=image,
strength=strength,
device=cf_device,
seed=seed,
fidelity=codeformer_fidelity,
)
else:
self.logger.info("Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image,
upscale[1],
seed,
int(upscale[0]),
denoise_str=upscale_denoise_str,
)
else:
self.logger.info("ESRGAN is disabled. Image not upscaled.")
except Exception as e:
self.logger.info(
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
)
if image_callback is not None:
image_callback(image, seed, upscaled=True, use_prefix=prefix)
else:
r[0] = image
results.append([image, seed])
return results

View File

@@ -30,8 +30,6 @@ from huggingface_hub import login as hf_hub_login
from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import (
AutoProcessor,
CLIPSegForImageSegmentation,
CLIPTextModel,
CLIPTokenizer,
AutoFeatureExtractor,
@@ -45,7 +43,6 @@ from invokeai.app.services.config import (
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import (
SingleSelectColumns,
CenteredButtonPress,
IntTitleSlider,
set_min_terminal_size,
@@ -226,64 +223,30 @@ def download_conversion_models():
# ---------------------------------------------
def download_realesrgan():
logger.info("Installing models from RealESRGAN...")
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-x4v3.pth"
wdn_model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
def download_gfpgan():
logger.info("Installing GFPGAN models...")
for model in (
[
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
"./models/core/face_restoration/gfpgan/GFPGANv1.4.pth",
],
[
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
"./models/core/face_restoration/gfpgan/weights/detection_Resnet50_Final.pth",
],
[
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
"./models/core/face_restoration/gfpgan/weights/parsing_parsenet.pth",
],
):
model_url, model_dest = model[0], config.root_path / model[1]
download_with_progress_bar(model_url, str(model_dest), "GFPGAN weights")
logger.info("Installing RealESRGAN models...")
URLs = [
dict(
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
description = "RealESRGAN_x4plus.pth",
),
dict(
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
description = "RealESRGAN_x4plus_anime_6B.pth",
),
dict(
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description = "ESRGAN_SRx4_DF2KOST_official.pth",
),
]
for model in URLs:
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
# ---------------------------------------------
def download_clipseg():
logger.info("Installing clipseg model for text-based masking...")
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
try:
hf_download_from_pretrained(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
hf_download_from_pretrained(CLIPSegForImageSegmentation, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
except Exception:
logger.info("Error installing clipseg model:")
logger.info(traceback.format_exc())
def download_support_models():
download_realesrgan()
download_clipseg()
download_conversion_models()
# -------------------------------------
@@ -666,7 +629,7 @@ def run_console_ui(
# The third argument is needed in the Windows 11 environment to
# launch a console window running this program.
set_min_terminal_size(MIN_COLS, MIN_LINES)
# the install-models application spawns a subprocess to install
# models, and will crash unless this is set before running.
@@ -858,9 +821,9 @@ def main():
download_support_models()
if opt.skip_sd_weights:
logger.info("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
logger.warning("SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST")
elif models_to_download:
logger.info("\n** DOWNLOADING DIFFUSION WEIGHTS **")
logger.info("DOWNLOADING DIFFUSION WEIGHTS")
process_and_execute(opt, models_to_download)
postscript(errors=errors)

View File

@@ -593,9 +593,12 @@ script, which will perform a full upgrade in place."""
config = InvokeAIAppConfig.get_config()
config.parse_args(['--root',str(dest_root)])
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
dest_is_setup = (dest_root / 'models/core').exists() and (dest_root / 'databases').exists()
if not dest_is_setup:
import invokeai.frontend.install.invokeai_configure
from invokeai.backend.install.invokeai_configure import initialize_rootdir
initialize_rootdir(dest_root, True)
do_migrate(src_root,dest_root)

View File

@@ -71,8 +71,6 @@ class ModelInstallList:
class InstallSelections():
install_models: List[str]= field(default_factory=list)
remove_models: List[str]=field(default_factory=list)
@dataclass
class ModelLoadInfo():
@@ -119,6 +117,7 @@ class ModelInstall(object):
# supplement with entries in models.yaml
installed_models = self.mgr.list_models()
for md in installed_models:
base = md['base_model']
model_type = md['model_type']
@@ -136,6 +135,12 @@ class ModelInstall(object):
)
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
def list_models(self, model_type):
installed = self.mgr.list_models(model_type=model_type)
print(f'Installed models of type `{model_type}`:')
for i in installed:
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
def starter_models(self)->Set[str]:
models = set()
for key, value in self.datasets.items():

View File

@@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, Chdir
from .model_cache import ModelCache, ModelLocker
from .model_search import ModelSearch
from .models import (
BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES,
@@ -322,16 +323,7 @@ class ModelManager(object):
self.config_meta = ConfigMeta(**config.pop("__metadata__"))
# TODO: metadata not found
# TODO: version check
self.app_config = InvokeAIAppConfig.get_config()
self.logger = logger
self.cache = ModelCache(
@@ -342,11 +334,41 @@ class ModelManager(object):
sequential_offload = sequential_offload,
logger = logger,
)
self._read_models(config)
def _read_models(self, config: Optional[DictConfig] = None):
if not config:
if self.config_path:
config = OmegaConf.load(self.config_path)
else:
return
self.models = dict()
for model_key, model_config in config.items():
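# entries whose keys start with an underscore (e.g. the "__metadata__" stanza
# popped in __init__) describe the config file itself rather than a model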
if model_key.startswith('_'):
continue
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
self.cache_keys = dict()
# add controlnet, lora and textual_inversion models from disk
self.scan_models_directory()
def sync_to_config(self):
"""
Call this when `models.yaml` has been changed externally.
This will reinitialize internal data structures
"""
# Reread models directory; note that this will reinitialize the cache,
# causing otherwise unreferenced models to be removed from memory
self._read_models()
def model_exists(
self,
model_name: str,
@@ -527,7 +549,10 @@ class ModelManager(object):
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
models = []
for model_key in model_keys:
model_config = self.models.get(model_key)
if not model_config:
self.logger.error(f'Unknown model {model_name}')
raise KeyError(f'Unknown model {model_name}')
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
if base_model is not None and cur_base_model != base_model:
@@ -589,6 +614,7 @@ class ModelManager(object):
rmtree(str(model_path))
else:
model_path.unlink()
self.commit()
# LS: tested
def add_model(
@@ -645,11 +671,61 @@ class ModelManager(object):
config = model_config,
)
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: Optional[str] = None,
new_base: Optional[BaseModelType] = None,
):
'''
Rename or rebase a model.
'''
if new_name is None and new_base is None:
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
return
model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.get(model_key, None)
if not model_cfg:
raise KeyError(f"Unknown model: {model_key}")
old_path = self.app_config.root_path / model_cfg.path
new_name = new_name or model_name
new_base = new_base or base_model
new_key = self.create_key(new_name, new_base, model_type)
if new_key in self.models:
raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"')
# if this is a model file/directory that we manage ourselves, we need to move it
if old_path.is_relative_to(self.app_config.models_path):
new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
move(old_path, new_path)
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
# clean up caches
old_model_cache = self._get_model_cache_path(old_path)
if old_model_cache.exists():
if old_model_cache.is_dir():
rmtree(str(old_model_cache))
else:
old_model_cache.unlink()
cache_ids = self.cache_keys.pop(model_key, [])
for cache_id in cache_ids:
self.cache.uncache_model(cache_id)
self.models.pop(model_key, None) # delete
self.models[new_key] = model_cfg
self.commit()
def convert_model (
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
dest_directory: Optional[Path]=None,
) -> AddModelResult:
'''
Convert a checkpoint file into a diffusers folder, deleting the cached
@@ -676,14 +752,14 @@ class ModelManager(object):
)
checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.app_config.models_path / model.location
new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
try:
move(old_diffusers_path,new_diffusers_path)
info["model_format"] = "diffusers"
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
info.pop('config')
result = self.add_model(model_name, base_model, model_type,
@@ -823,6 +899,7 @@ class ModelManager(object):
if (new_models_found or imported_models) and self.config_path:
self.commit()
def autoimport(self)->Dict[str, AddModelResult]:
'''
Scan the autoimport directory (if defined) and import new models, delete defunct models.
@@ -830,63 +907,41 @@ class ModelManager(object):
# avoid circular import
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
class ScanAndImport(ModelSearch):
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
super().__init__(directories, logger)
self.installer = installer
self.ignore = ignore
def on_search_started(self):
self.new_models_found = dict()
def on_model_found(self, model: Path):
if model not in self.ignore:
self.new_models_found.update(self.installer.heuristic_import(model))
def on_search_completed(self):
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
def models_found(self):
return self.new_models_found
installer = ModelInstall(config = self.app_config,
model_manager = self,
prediction_type_helper = ask_user_for_prediction_type,
)
config = self.app_config
known_paths = {config.root_path / x['path'] for x in self.list_models()}
directories = {config.root_path / x for x in [config.autoimport_dir,
config.lora_dir,
config.embedding_dir,
config.controlnet_dir]
}
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
scanner.search()
return scanner.models_found()
def heuristic_import(self,
items_to_import: Set[str],
@@ -924,3 +979,4 @@ class ModelManager(object):
successfully_installed.update(installed)
self.commit()
return successfully_installed

View File

@@ -11,7 +11,7 @@ from enum import Enum
from pathlib import Path
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from typing import List, Union, Optional
import invokeai.backend.util.logging as logger
@@ -74,6 +74,7 @@ class ModelMerger(object):
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
merge_dest_directory: Optional[Path] = None,
**kwargs,
) -> AddModelResult:
"""
@@ -85,7 +86,7 @@ class ModelMerger(object):
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
@@ -111,7 +112,7 @@ class ModelMerger(object):
merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs
)
dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name

View File

@@ -0,0 +1,103 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
"""
import os
from abc import ABC, abstractmethod
import types
from typing import List, Set
from pathlib import Path
import invokeai.backend.util.logging as logger
class ModelSearch(ABC):
def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
"""
Initialize a recursive model directory search.
:param directories: List of directory Paths to recurse through
:param logger: Logger to use
"""
self.directories = directories
self.logger = logger
self._items_scanned = 0
self._models_found = 0
self._scanned_dirs = set()
self._scanned_paths = set()
self._pruned_paths = set()
@abstractmethod
def on_search_started(self):
"""
Called before the scan starts.
"""
pass
@abstractmethod
def on_model_found(self, model: Path):
"""
Process a found model. Raise an exception if something goes wrong.
:param model: Model to process - could be a directory or checkpoint.
"""
pass
@abstractmethod
def on_search_completed(self):
"""
Perform some activity when the scan is completed. May use instance
variables, items_scanned and models_found
"""
pass
def search(self):
self.on_search_started()
for dir in self.directories:
self.walk_directory(dir)
self.on_search_completed()
def walk_directory(self, path: Path):
for root, dirs, files in os.walk(path):
if str(Path(root).name).startswith('.'):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue
self._items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in self._scanned_paths or path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
try:
self.on_model_found(path)
self._models_found += 1
self._scanned_dirs.add(path)
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
try:
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(str(e))
class FindModels(ModelSearch):
def on_search_started(self):
self.models_found: Set[Path] = set()
def on_model_found(self,model: Path):
self.models_found.add(model)
def on_search_completed(self):
pass
def list_models(self) -> List[Path]:
self.search()
return list(self.models_found)
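A short usage sketch of the concrete FindModels subclass above; the directory is a placeholder:

from pathlib import Path
import invokeai.backend.util.logging as logger

finder = FindModels([Path("/data/checkpoints")], logger)
for model in finder.list_models():   # walks the tree, pruning dot-directories
    print(model)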

View File

@@ -48,7 +48,9 @@ for base_model, models in MODEL_CLASSES.items():
model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs)
# LS: sort to get the checkpoint configs first, which makes
# for a better template in the Swagger docs
for cfg in sorted(model_configs, key=lambda x: str(x)):
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():

View File

@@ -59,7 +59,6 @@ class ModelConfigBase(BaseModel):
path: str # or Path
description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None)
error: Optional[ModelError] = Field(None)
class Config:

View File

@@ -1,8 +1,7 @@
import os
import torch
from enum import Enum
from pathlib import Path
from typing import Optional
from .base import (
ModelBase,
ModelConfigBase,
@@ -14,6 +13,7 @@ from .base import (
calc_model_size_by_data,
classproperty,
InvalidModelException,
ModelNotFoundException,
)
class ControlNetModelFormat(str, Enum):
@@ -60,10 +60,20 @@ class ControlNetModel(ModelBase):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
model = None
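# try the fp16 variant of the weights first; fall back to full precision when
# the repo does not ship an fp16 variant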
for variant in ['fp16',None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
variant=variant,
)
break
except Exception:
pass
if not model:
raise ModelNotFoundException()
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model

View File

@@ -37,8 +37,7 @@ class StableDiffusion1Model(DiffusersModel):
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.Main

View File

@@ -1,4 +0,0 @@
"""
Initialization file for the invokeai.backend.restoration package
"""
from .base import Restoration

View File

@@ -1,45 +0,0 @@
import invokeai.backend.util.logging as logger
class Restoration:
def __init__(self) -> None:
pass
def load_face_restore_models(
self, gfpgan_model_path="./models/core/face_restoration/gfpgan/GFPGANv1.4.pth"
):
# Load GFPGAN
gfpgan = self.load_gfpgan(gfpgan_model_path)
if gfpgan.gfpgan_model_exists:
logger.info("GFPGAN Initialized")
else:
logger.info("GFPGAN Disabled")
gfpgan = None
# Load CodeFormer
codeformer = self.load_codeformer()
if codeformer.codeformer_model_exists:
logger.info("CodeFormer Initialized")
else:
logger.info("CodeFormer Disabled")
codeformer = None
return gfpgan, codeformer
# Face Restore Models
def load_gfpgan(self, gfpgan_model_path):
from .gfpgan import GFPGAN
return GFPGAN(gfpgan_model_path)
def load_codeformer(self):
from .codeformer import CodeFormerRestoration
return CodeFormerRestoration()
# Upscale Models
def load_esrgan(self, esrgan_bg_tile=400):
from .realesrgan import ESRGAN
esrgan = ESRGAN(esrgan_bg_tile)
logger.info("ESRGAN Initialized")
return esrgan

View File

@@ -1,120 +0,0 @@
import os
import sys
import warnings
import numpy as np
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
pretrained_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
class CodeFormerRestoration:
def __init__(
self, codeformer_dir="./models/core/face_restoration/codeformer", codeformer_model_path="codeformer.pth"
) -> None:
self.globals = InvokeAIAppConfig.get_config()
codeformer_dir = self.globals.root_dir / codeformer_dir
self.model_path = codeformer_dir / codeformer_model_path
self.codeformer_model_exists = self.model_path.exists()
if not self.codeformer_model_exists:
logger.error(f"NOT FOUND: CodeFormer model not found at {self.model_path}")
sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75):
if seed is not None:
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from PIL import Image
from torchvision.transforms.functional import normalize
from .codeformer_arch import CodeFormer
cf_class = CodeFormer
cf = cf_class(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=["32", "64", "128", "256"],
).to(device)
# note that this file should already be downloaded and cached at
# this point
checkpoint_path = load_file_from_url(
url=pretrained_model_url,
model_dir=os.path.abspath(os.path.dirname(self.model_path)),
progress=True,
)
checkpoint = torch.load(checkpoint_path)["params_ema"]
cf.load_state_dict(checkpoint)
cf.eval()
image = image.convert("RGB")
# CodeFormer expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
face_helper = FaceRestoreHelper(
upscale_factor=1,
use_parse=True,
device=device,
model_rootpath = self.globals.model_path / 'core/face_restoration/gfpgan/weights'
)
face_helper.clean_all()
face_helper.read_image(bgr_image_array)
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
for idx, cropped_face in enumerate(face_helper.cropped_faces):
cropped_face_t = img2tensor(
cropped_face / 255.0, bgr2rgb=True, float32=True
)
normalize(
cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True
)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try:
with torch.no_grad():
output = cf(cropped_face_t, w=fidelity, adain=True)[0]
restored_face = tensor2img(
output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
)
del output
torch.cuda.empty_cache()
except RuntimeError as error:
logger.error(f"Failed inference for CodeFormer: {error}.")
restored_face = cropped_face
restored_face = restored_face.astype("uint8")
face_helper.add_restored_face(restored_face)
face_helper.get_inverse_affine(None)
restored_img = face_helper.paste_faces_to_input_image()
# Flip the channels back to RGB
res = Image.fromarray(restored_img[..., ::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
cf = None
return res
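All three restorers in this package finish the same way: the restored output is blended back into the original by strength. A short sketch of that final step with PIL, assuming both images are already RGB:

from PIL import Image

def blend_by_strength(original: Image.Image, restored: Image.Image, strength: float) -> Image.Image:
    # strength 0.0 keeps the original; 1.0 keeps the fully restored image
    if strength >= 1.0:
        return restored
    if original.size != restored.size:
        original = original.resize(restored.size)
    return Image.blend(original, restored, strength)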

View File

@@ -1,325 +0,0 @@
import math
from typing import List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
from torch import Tensor, nn
from .vqgan_arch import *
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, "The input feature should be 4D tensor."
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
Adjust the reference features to have color and illumination statistics
similar to those of the degraded features.
Args:
content_feat (Tensor): The reference features.
style_feat (Tensor): The degraded features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
size
)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask=None):
if mask is None:
mask = torch.zeros(
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
class TransformerSALayer(nn.Module):
def __init__(
self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(
self,
tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
return tgt
class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.encode_enc = ResBlock(2 * in_ch, out_ch)
self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
shift = self.shift(enc_feat)
residual = w * (dec_feat * scale + shift)
out = dec_feat + residual
return out
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
def __init__(
self,
dim_embd=512,
n_head=8,
n_layers=9,
codebook_size=1024,
latent_size=256,
connect_list=["32", "64", "128", "256"],
fix_modules=["quantize", "generator"],
):
super(CodeFormer, self).__init__(
512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
)
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():
param.requires_grad = False
self.connect_list = connect_list
self.n_layers = n_layers
self.dim_embd = dim_embd
self.dim_mlp = dim_embd * 2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
self.ft_layers = nn.Sequential(
*[
TransformerSALayer(
embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
)
for _ in range(self.n_layers)
]
)
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
)
self.channels = {
"16": 512,
"32": 256,
"64": 256,
"128": 128,
"256": 128,
"512": 64,
}
# after second residual block for > 16, before attn layer for ==16
self.fuse_encoder_block = {
"512": 2,
"256": 5,
"128": 8,
"64": 11,
"32": 14,
"16": 18,
}
# after first residual block for > 16, before attn layer for ==16
self.fuse_generator_block = {
"16": 6,
"32": 9,
"64": 12,
"128": 15,
"256": 18,
"512": 21,
}
# fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict()
for f_size in self.connect_list:
in_ch = self.channels[f_size]
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
lq_feat = x
# ################# Transformer ###################
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
# BCHW -> BC(HW) -> (HW)BC
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
query_emb = feat_emb
# Transformer encoder
for layer in self.ft_layers:
query_emb = layer(query_emb, query_pos=pos_emb)
# output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn
logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
if code_only: # for training stage II
# logits don't need softmax before cross_entropy loss
return logits, lq_feat
# ################# Quantization ###################
# if self.training:
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
# # b(hw)c -> bc(hw) -> bchw
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
# ------------
soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(
top_idx, shape=[x.shape[0], 16, 16, 256]
)
# preserve gradients
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
if detach_16:
quant_feat = quant_feat.detach() # for training stage III
if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
# ################## Generator ####################
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w > 0:
x = self.fuse_convs_dict[f_size](
enc_feat_dict[f_size].detach(), x, w
)
out = x
# logits don't need softmax before cross_entropy loss
return out, logits, lq_feat
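For reference, the adain path above re-normalizes the quantized features with the per-channel statistics of the degraded features. A compact sketch of the same computation (equivalent to calc_mean_std plus adaptive_instance_normalization above):

import torch

def adain(content: torch.Tensor, style: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # per-channel mean/std over the spatial dims of (B, C, H, W) tensors
    b, c = content.shape[:2]
    c_mean = content.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    c_std = (content.view(b, c, -1).var(dim=2) + eps).sqrt().view(b, c, 1, 1)
    s_mean = style.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    s_std = (style.view(b, c, -1).var(dim=2) + eps).sqrt().view(b, c, 1, 1)
    return (content - c_mean) / c_std * s_std + s_mean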

View File

@@ -1,84 +0,0 @@
import os
import sys
import warnings
import numpy as np
import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
self.globals = InvokeAIAppConfig.get_config()
if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
self.model_path = gfpgan_model_path
self.gfpgan_model_exists = os.path.isfile(self.model_path)
if not self.gfpgan_model_exists:
logger.error(f"NOT FOUND: GFPGAN model not found at {self.model_path}")
return None
def model_exists(self):
return os.path.isfile(self.model_path)
def process(self, image, strength: float, seed: str = None):
if seed is not None:
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
cwd = os.getcwd()
os.chdir(self.globals.root_dir / 'models')
try:
from gfpgan import GFPGANer
self.gfpgan = GFPGANer(
model_path=self.model_path,
upscale=1,
arch="clean",
channel_multiplier=2,
bg_upsampler=None,
)
except Exception:
import traceback
logger.error("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
os.chdir(cwd)
if self.gfpgan is None:
logger.warning("WARNING: GFPGAN not initialized.")
logger.warning(
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
)
return image  # no model available; return the input unchanged
image = image.convert("RGB")
# GFPGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
_, _, restored_img = self.gfpgan.enhance(
bgr_image_array,
has_aligned=False,
only_center_face=False,
paste_back=True,
)
# Flip the channels back to RGB
res = Image.fromarray(restored_img[..., ::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.gfpgan = None
return res

View File

@@ -1,118 +0,0 @@
import math
from PIL import Image
import invokeai.backend.util.logging as logger
class Outcrop(object):
def __init__(
self,
image,
generate, # current generate object
):
self.image = image
self.generate = generate
def process(
self,
extents: dict,
opt, # current options
orig_opt, # ones originally used to generate the image
image_callback=None,
prefix=None,
):
# grow and mask the image
extended_image = self._extend_all(extents)
# switch samplers temporarily
curr_sampler = self.generate.sampler
self.generate.sampler_name = opt.sampler_name
self.generate._set_scheduler()
def wrapped_callback(img, seed, **kwargs):
preferred_seed = (
orig_opt.seed
if orig_opt.seed is not None and orig_opt.seed >= 0
else seed
)
image_callback(img, preferred_seed, use_prefix=prefix, **kwargs)
result = self.generate.prompt2image(
opt.prompt,
seed=opt.seed or orig_opt.seed,
sampler=self.generate.sampler,
steps=opt.steps,
cfg_scale=opt.cfg_scale,
ddim_eta=self.generate.ddim_eta,
width=extended_image.width,
height=extended_image.height,
init_img=extended_image,
strength=0.90,
image_callback=wrapped_callback if image_callback else None,
seam_size=opt.seam_size or 96,
seam_blur=opt.seam_blur or 16,
seam_strength=opt.seam_strength or 0.7,
seam_steps=20,
tile_size=32,
color_match=True,
force_outpaint=True, # this just stops the warning about erased regions
)
# swap sampler back
self.generate.sampler = curr_sampler
return result
def _extend_all(
self,
extents: dict,
) -> Image:
"""
Extend the image in direction ('top','bottom','left','right') by
the indicated value. The image canvas is extended, and the empty
rectangular section will be filled with a blurred copy of the
adjacent image.
"""
image = self.image
for direction in extents:
assert direction in [
"top",
"left",
"bottom",
"right",
], 'Direction must be one of "top", "left", "bottom", "right"'
pixels = extents[direction]
# round pixels up to the nearest 64
pixels = math.ceil(pixels / 64) * 64
logger.info(f"extending image {direction}ward by {pixels} pixels")
image = self._rotate(image, direction)
image = self._extend(image, pixels)
image = self._rotate(image, direction, reverse=True)
return image
def _rotate(self, image: Image, direction: str, reverse=False) -> Image:
"""
Rotates image so that the area to extend is always at the top.
Simplifies logic later. The reverse argument, if true, will undo the
previous transpose.
"""
transposes = {
"right": ["ROTATE_90", "ROTATE_270"],
"bottom": ["ROTATE_180", "ROTATE_180"],
"left": ["ROTATE_270", "ROTATE_90"],
}
if direction not in transposes:
return image
transpose = transposes[direction][1 if reverse else 0]
return image.transpose(Image.Transpose.__dict__[transpose])
def _extend(self, image: Image, pixels: int) -> Image:
extended_img = Image.new("RGBA", (image.width, image.height + pixels))
extended_img.paste((0, 0, 0), [0, 0, image.width, image.height + pixels])
extended_img.paste(image, box=(0, pixels))
# now make the top part transparent to use as a mask
alpha = extended_img.getchannel("A")
alpha.paste(0, (0, 0, extended_img.width, pixels))
extended_img.putalpha(alpha)
return extended_img

View File

@@ -1,102 +0,0 @@
import math
import warnings
from PIL import Image, ImageFilter
class Outpaint(object):
def __init__(self, image, generate):
self.image = image
self.generate = generate
def process(self, opt, old_opt, image_callback=None, prefix=None):
image = self._create_outpaint_image(self.image, opt.out_direction)
seed = old_opt.seed
prompt = old_opt.prompt
def wrapped_callback(img, seed, **kwargs):
image_callback(img, seed, use_prefix=prefix, **kwargs)
return self.generate.prompt2image(
prompt,
seed=seed,
sampler=self.generate.sampler,
steps=opt.steps,
cfg_scale=opt.cfg_scale,
ddim_eta=self.generate.ddim_eta,
width=opt.width,
height=opt.height,
init_img=image,
strength=0.83,
image_callback=wrapped_callback,
prefix=prefix,
)
def _create_outpaint_image(self, image, direction_args):
assert len(direction_args) in [
1,
2,
], "Direction (-D) must have exactly one or two arguments."
if len(direction_args) == 1:
direction = direction_args[0]
pixels = None
elif len(direction_args) == 2:
direction = direction_args[0]
pixels = int(direction_args[1])
assert direction in [
"top",
"left",
"bottom",
"right",
], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
image = image.convert("RGBA")
# we always extend top, but rotate to extend along the requested side
if direction == "left":
image = image.transpose(Image.Transpose.ROTATE_270)
elif direction == "bottom":
image = image.transpose(Image.Transpose.ROTATE_180)
elif direction == "right":
image = image.transpose(Image.Transpose.ROTATE_90)
pixels = image.height // 2 if pixels is None else int(pixels)
assert (
0 < pixels < image.height
), "Direction (-D) pixels length must be in the range 0 - image.size"
# the top part of the image is taken from the source image mirrored
# coordinates (0,0) are the upper left corner of an image
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
top = top.crop((0, top.height - pixels, top.width, top.height))
# setting all alpha of the top part to 0
alpha = top.getchannel("A")
alpha.paste(0, (0, 0, top.width, top.height))
top.putalpha(alpha)
# taking the bottom from the original image
bottom = image.crop((0, 0, image.width, image.height - pixels))
new_img = image.copy()
new_img.paste(top, (0, 0))
new_img.paste(bottom, (0, pixels))
# create a 10% dither in the middle
dither = min(image.height // 10, pixels)
for x in range(0, image.width, 2):
for y in range(pixels - dither, pixels + dither):
(r, g, b, a) = new_img.getpixel((x, y))
new_img.putpixel((x, y), (r, g, b, 0))
# let's rotate back again
if direction == "left":
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
elif direction == "bottom":
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
elif direction == "right":
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
return new_img

View File

@@ -1,104 +0,0 @@
import warnings
import numpy as np
import torch
from PIL import Image
from PIL.Image import Image as ImageType
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
class ESRGAN:
def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size
def load_esrgan_bg_upsampler(self, denoise_str):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
)
model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-x4v3.pth"
wdn_model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
scale = 4
bg_upsampler = RealESRGANer(
scale=scale,
model_path=[model_path, wdn_model_path],
model=model,
tile=self.bg_tile_size,
dni_weight=[denoise_str, 1 - denoise_str],
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
return bg_upsampler
def process(
self,
image: ImageType,
strength: float,
seed: str = None,
upsampler_scale: int = 2,
denoise_str: float = 0.75,
):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
try:
upsampler = self.load_esrgan_bg_upsampler(denoise_str)
except Exception:
import sys
import traceback
logger.error("Error loading Real-ESRGAN:")
print(traceback.format_exc(), file=sys.stderr)
return image  # cannot upscale without an upsampler
if upsampler_scale == 0:
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
return image
if seed is not None:
logger.info(
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
)
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
image = image.convert("RGB")
# Real-ESRGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
output, _ = upsampler.enhance(
bgr_image_array,
outscale=upsampler_scale,
alpha_upsampler="realesrgan",
)
# Flip the channels back to RGB
res = Image.fromarray(output[..., ::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if output.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
upsampler = None
return res

View File

@@ -1,514 +0,0 @@
"""
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
"""
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
@torch.jit.script
def swish(x):
return x * torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(
-1.0 / self.codebook_size, 1.0 / self.codebook_size
)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
(z_flattened**2).sum(dim=1, keepdim=True)
+ (self.embedding.weight**2).sum(1)
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
)
mean_distance = torch.mean(d)
# find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(
d, 1, dim=1, largest=False
)
# [0-1], higher score, higher confidence
min_encoding_scores = torch.exp(-min_encoding_scores / 10)
min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.codebook_size
).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return (
z_q,
loss,
{
"perplexity": perplexity,
"min_encodings": min_encodings,
"min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance,
},
)
def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel
indices = indices.view(-1, 1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q
class GumbelQuantizer(nn.Module):
def __init__(
self,
codebook_size,
emb_dim,
num_hiddens,
straight_through=False,
kl_weight=5e-4,
temp_init=1.0,
):
super().__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(
num_hiddens, codebook_size, 1
) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z):
hard = self.straight_through if self.training else True
logits = self.proj(z)
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = (
self.kl_weight
* torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
)
min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, {"min_encoding_indices": min_encoding_indices}
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None):
super(ResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.norm1 = normalize(in_channels)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = normalize(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
self.conv_out = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x_in):
x = x_in
x = self.norm1(x)
x = swish(x)
x = self.conv1(x)
x = self.norm2(x)
x = swish(x)
x = self.conv2(x)
if self.in_channels != self.out_channels:
x_in = self.conv_out(x_in)
return x + x_in
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
class Encoder(nn.Module):
def __init__(
self,
in_channels,
nf,
emb_dim,
ch_mult,
num_res_blocks,
resolution,
attn_resolutions,
):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,) + tuple(ch_mult)
blocks = []
# initial convolution
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
# residual and downsampling blocks, with attention on smaller res (16x16)
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != self.num_resolutions - 1:
blocks.append(Downsample(block_in_ch))
curr_res = curr_res // 2
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
# normalise and convert to latent size
blocks.append(normalize(block_in_ch))
blocks.append(
nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
blocks = []
# initial conv
blocks.append(
nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
)
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in self.attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != 0:
blocks.append(Upsample(block_in_ch))
curr_res = curr_res * 2
blocks.append(normalize(block_in_ch))
blocks.append(
nn.Conv2d(
block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
)
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(
self,
img_size,
nf,
ch_mult,
quantizer="nearest",
res_blocks=2,
attn_resolutions=[16],
codebook_size=1024,
emb_dim=256,
beta=0.25,
gumbel_straight_through=False,
gumbel_kl_weight=1e-8,
model_path=None,
):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions,
)
if self.quantizer_type == "nearest":
self.beta = beta # 0.25
self.quantize = VectorQuantizer(
self.codebook_size, self.embed_dim, self.beta
)
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer(
self.codebook_size,
self.embed_dim,
self.gumbel_num_hiddens,
self.straight_through,
self.kl_weight,
)
self.generator = Generator(
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions,
)
if model_path is not None:
chkpt = torch.load(model_path, map_location="cpu")
if "params_ema" in chkpt:
self.load_state_dict(chkpt["params_ema"])
logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
elif "params" in chkpt:
self.load_state_dict(chkpt["params"])
logger.info(f"vqgan is loaded from: {model_path} [params]")
else:
raise ValueError("Wrong params! Expected 'params_ema' or 'params' in checkpoint.")
def forward(self, x):
x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant)
return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__()
layers = [
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, True),
]
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2**n, 8)
layers += [
nn.Conv2d(
ndf * ndf_mult_prev,
ndf * ndf_mult,
kernel_size=4,
stride=2,
padding=1,
bias=False,
),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True),
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2**n_layers, 8)
layers += [
nn.Conv2d(
ndf * ndf_mult_prev,
ndf * ndf_mult,
kernel_size=4,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True),
]
layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
] # output 1 channel prediction map
self.main = nn.Sequential(*layers)
if model_path is not None:
chkpt = torch.load(model_path, map_location="cpu")
if "params_d" in chkpt:
self.load_state_dict(chkpt["params_d"])
elif "params" in chkpt:
self.load_state_dict(chkpt["params"])
else:
raise ValueError("Wrong params! Expected 'params_d' or 'params' in checkpoint.")
def forward(self, x):
return self.main(x)
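The heart of VectorQuantizer.forward above is a nearest-neighbour lookup into the codebook. A few lines showing the distance trick in isolation, with made-up sizes:

import torch

z_flat = torch.randn(16, 256)        # (tokens, emb_dim)
codebook = torch.randn(1024, 256)    # (codebook_size, emb_dim)
# squared distances via ||z||^2 + ||e||^2 - 2*z.e, as in the forward pass above
d = (z_flat**2).sum(1, keepdim=True) + (codebook**2).sum(1) - 2 * z_flat @ codebook.t()
idx = d.argmin(dim=1)                # closest code per token
z_q = codebook[idx]                  # quantized latents, shape (16, 256)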

View File

@@ -241,11 +241,45 @@ class InvokeAIDiffuserComponent:
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
# fast batched path
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
if cond.shape[1] < target_len:
conditioning_attention_mask = torch.cat([
conditioning_attention_mask,
torch.zeros((cond.shape[0], target_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
], dim=1)
cond = torch.cat([
cond,
torch.zeros((cond.shape[0], target_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
], dim=1)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([
encoder_attention_mask,
conditioning_attention_mask,
])
return cond, encoder_attention_mask
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
encoder_attention_mask = None
if unconditioning.shape[1] != conditioning.shape[1]:
max_len = max(unconditioning.shape[1], conditioning.shape[1])
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
both_conditionings = torch.cat([unconditioning, conditioning])
both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, **kwargs,
x_twice, sigma_twice, both_conditionings,
encoder_attention_mask=encoder_attention_mask,
**kwargs,
)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
return unconditioned_next_x, conditioned_next_x
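The hunk above zero-pads the shorter of the two conditioning tensors and builds an encoder attention mask so the padded positions are ignored. A self-contained sketch with illustrative shapes:

import torch

def pad_conditioning(cond: torch.Tensor, target_len: int):
    # cond: (batch, seq_len, dim); mask has 1s for real tokens, 0s for padding
    mask = torch.ones(cond.shape[0], cond.shape[1], dtype=cond.dtype)
    if cond.shape[1] < target_len:
        pad = target_len - cond.shape[1]
        cond = torch.cat([cond, cond.new_zeros(cond.shape[0], pad, cond.shape[2])], dim=1)
        mask = torch.cat([mask, mask.new_zeros(mask.shape[0], pad)], dim=1)
    return cond, mask

uncond = torch.randn(1, 77, 768)
cond = torch.randn(1, 154, 768)  # e.g. a long prompt spanning two chunks
max_len = max(uncond.shape[1], cond.shape[1])
uncond, m_u = pad_conditioning(uncond, max_len)
cond, m_c = pad_conditioning(cond, max_len)
both = torch.cat([uncond, cond])      # stacked along the batch axis, as above
mask = torch.cat([m_u, m_c])          # matching mask, also along the batch axis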

View File

@@ -24,7 +24,7 @@ import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from accelerate.utils import set_seed, ProjectConfiguration
from diffusers import (
AutoencoderKL,
DDPMScheduler,
@@ -35,7 +35,6 @@ from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami
from omegaconf import OmegaConf
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
@@ -47,6 +46,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
# invokeai stuff
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
from invokeai.app.services.model_manager_service import ModelManagerService
from invokeai.backend.model_management.models import SubModelType
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
@@ -132,7 +133,7 @@ def parse_args():
model_group.add_argument(
"--model",
type=str,
default="stable-diffusion-1.5",
default="sd-1/main/stable-diffusion-v1-5",
help="Name of the diffusers model to train against, as defined in configs/models.yaml.",
)
model_group.add_argument(
@@ -565,7 +566,6 @@ def do_textual_inversion_training(
checkpointing_steps: int = 500,
resume_from_checkpoint: Path = None,
enable_xformers_memory_efficient_attention: bool = False,
root_dir: Path = None,
hub_model_id: str = None,
**kwargs,
):
@@ -584,13 +584,17 @@ def do_textual_inversion_training(
logging_dir = output_dir / logging_dir
accelerator_config = ProjectConfiguration()
accelerator_config.logging_dir = logging_dir
accelerator = Accelerator(
gradient_accumulation_steps=gradient_accumulation_steps,
mixed_precision=mixed_precision,
log_with=report_to,
logging_dir=logging_dir,
project_config=accelerator_config,
)
model_manager = ModelManagerService(config, logger)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -628,46 +632,46 @@ def do_textual_inversion_training(
elif output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
models_conf = OmegaConf.load(config.model_conf_path)
model_conf = models_conf.get(model, None)
assert model_conf is not None, f"Unknown model: {model}"
known_models = model_manager.model_names()
model_name = model.split('/')[-1]
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
assert model_meta is not None, f"Unknown model: {model}"
model_info = model_manager.model_info(*model_meta)
assert (
model_conf.get("format", "diffusers") == "diffusers"
model_info['model_format'] == "diffusers"
), "This script only works with models of type 'diffusers'"
pretrained_model_name_or_path = model_conf.get("repo_id", None) or Path(
model_conf.get("path")
)
assert (
pretrained_model_name_or_path
), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
pipeline_args = dict(cache_dir=config.cache_dir)
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
# Load tokenizer
pipeline_args = dict(local_files_only=True)
if tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
else:
tokenizer = CLIPTokenizer.from_pretrained(
pretrained_model_name_or_path, subfolder="tokenizer", **pipeline_args
tokenizer_info.location, subfolder='tokenizer', **pipeline_args
)
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(
pretrained_model_name_or_path, subfolder="scheduler", **pipeline_args
noise_scheduler_info.location, subfolder="scheduler", **pipeline_args
)
text_encoder = CLIPTextModel.from_pretrained(
pretrained_model_name_or_path,
text_encoder_info.location,
subfolder="text_encoder",
revision=revision,
**pipeline_args,
)
vae = AutoencoderKL.from_pretrained(
pretrained_model_name_or_path,
vae_info.location,
subfolder="vae",
revision=revision,
**pipeline_args,
)
unet = UNet2DConditionModel.from_pretrained(
pretrained_model_name_or_path,
unet_info.location,
subfolder="unet",
revision=revision,
**pipeline_args,
@@ -989,7 +993,7 @@ def do_textual_inversion_training(
save_full_model = not only_save_embeds
if save_full_model:
pipeline = StableDiffusionPipeline.from_pretrained(
pretrained_model_name_or_path,
unet_info.location,
text_encoder=accelerator.unwrap_model(text_encoder),
vae=vae,
unet=unet,

View File

@@ -58,22 +58,29 @@ sd-1/main/waifu-diffusion:
recommended: False
sd-1/controlnet/canny:
repo_id: lllyasviel/control_v11p_sd15_canny
recommended: True
sd-1/controlnet/inpaint:
repo_id: lllyasviel/control_v11p_sd15_inpaint
sd-1/controlnet/mlsd:
repo_id: lllyasviel/control_v11p_sd15_mlsd
sd-1/controlnet/depth:
repo_id: lllyasviel/control_v11f1p_sd15_depth
recommended: True
sd-1/controlnet/normal_bae:
repo_id: lllyasviel/control_v11p_sd15_normalbae
sd-1/controlnet/seg:
repo_id: lllyasviel/control_v11p_sd15_seg
sd-1/controlnet/lineart:
repo_id: lllyasviel/control_v11p_sd15_lineart
recommended: True
sd-1/controlnet/lineart_anime:
repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
sd-1/controlnet/openpose:
repo_id: lllyasviel/control_v11p_sd15_openpose
recommended: True
sd-1/controlnet/scribble:
repo_id: lllyasviel/control_v11p_sd15_scribble
recommended: False
sd-1/controlnet/softedge:
repo_id: lllyasviel/control_v11p_sd15_softedge
sd-1/controlnet/shuffle:
@@ -84,6 +91,7 @@ sd-1/controlnet/ip2p:
repo_id: lllyasviel/control_v11e_sd15_ip2p
sd-1/embedding/EasyNegative:
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
recommended: True
sd-1/embedding/ahx-beta-453407d:
repo_id: sd-concepts-library/ahx-beta-453407d
sd-1/lora/LowRA:

View File

@@ -256,6 +256,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
widgets = dict()
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
model_labels = [self.model_labels[x] for x in model_list]
show_recommended = len(self.installed_models) == 0
if len(model_list) > 0:
max_width = max([len(x) for x in model_labels])
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
@@ -280,7 +282,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
value=[
model_list.index(x)
for x in model_list
if self.all_models[x].installed
if (show_recommended and self.all_models[x].recommended) \
or self.all_models[x].installed
],
max_height=len(model_list)//columns + 1,
relx=4,
@@ -672,7 +675,9 @@ def select_and_download_models(opt: Namespace):
# pass
installer = ModelInstall(config, prediction_type_helper=helper)
if opt.add or opt.delete:
if opt.list_models:
installer.list_models(opt.list_models)
elif opt.add or opt.delete:
selections = InstallSelections(
install_models = opt.add or [],
remove_models = opt.delete or []
@@ -696,7 +701,7 @@ def select_and_download_models(opt: Namespace):
# the third argument is needed in the Windows 11 environment in
# order to launch and resize a console window running this program
set_min_terminal_size(MIN_COLS, MIN_LINES,'invokeai-model-install')
set_min_terminal_size(MIN_COLS, MIN_LINES)
installApp = AddModelApplication(opt)
try:
installApp.run()
@@ -745,7 +750,7 @@ def main():
)
parser.add_argument(
"--list-models",
choices=["diffusers","loras","controlnets","tis"],
choices=[x.value for x in ModelType],
help="list installed models",
)
parser.add_argument(
@@ -773,7 +778,7 @@ def main():
config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config)
if not (config.conf_path / 'models.yaml').exists():
if not config.model_conf_path.exists():
logger.info(
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
)

View File

@@ -17,28 +17,20 @@ from shutil import get_terminal_size
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
# minimum size for UIs
MIN_COLS = 130
MIN_COLS = 136
MIN_LINES = 45
# -------------------------------------
def set_terminal_size(columns: int, lines: int, launch_command: str=None):
def set_terminal_size(columns: int, lines: int):
ts = get_terminal_size()
width = max(columns,ts.columns)
height = max(lines,ts.lines)
OS = platform.uname().system
if OS == "Windows":
# The new Windows Terminal doesn't resize, so we relaunch in a CMD window.
# Would prefer to use execvpe() here, but somehow it is not working properly
# in the Windows 10 environment.
if 'IA_RELAUNCHED' not in os.environ:
args=['conhost']
args.extend([launch_command] if launch_command else [sys.argv[0]])
args.extend(sys.argv[1:])
os.environ['IA_RELAUNCHED'] = 'True'
os.execvp('conhost',args)
else:
_set_terminal_size_powershell(width,height)
pass
# not working reliably - ask user to adjust the window
#_set_terminal_size_powershell(width,height)
elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width,height)
@@ -84,20 +76,14 @@ def _set_terminal_size_unix(width: int, height: int):
sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width))
sys.stdout.flush()
def set_min_terminal_size(min_cols: int, min_lines: int, launch_command: str=None):
def set_min_terminal_size(min_cols: int, min_lines: int):
# make sure there's enough room for the ui
term_cols, term_lines = get_terminal_size()
if term_cols >= min_cols and term_lines >= min_lines:
return
cols = max(term_cols, min_cols)
lines = max(term_lines, min_lines)
set_terminal_size(cols, lines, launch_command)
# did it work?
term_cols, term_lines = get_terminal_size()
if term_cols < cols or term_lines < lines:
print(f'This window is too small for optimal display. For best results please enlarge it.')
input('After resizing, press any key to continue...')
set_terminal_size(cols, lines)
class IntSlider(npyscreen.Slider):
def translate_value(self):
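On macOS and Linux the resize above still works via the xterm window-manipulation escape sequence (ESC [ 8 ; lines ; cols t), as used by _set_terminal_size_unix. A minimal sketch of that mechanism; whether a given emulator honors it varies:

import sys
from shutil import get_terminal_size

def resize_terminal(columns: int, lines: int) -> None:
    # xterm control sequence: ESC [ 8 ; <lines> ; <columns> t
    # honored by most xterm-compatible emulators, silently ignored elsewhere
    sys.stdout.write(f"\x1b[8;{lines};{columns}t")
    sys.stdout.flush()

resize_terminal(136, 45)  # the MIN_COLS / MIN_LINES values from this file
print(get_terminal_size())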

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-078526aa.js"></script>
<script type="module" crossorigin src="./assets/index-f1a5f9cf.js"></script>
</head>
<body dir="ltr">

View File

@@ -102,7 +102,8 @@
"openInNewTab": "Open in New Tab",
"dontAskMeAgain": "Don't ask me again",
"areYouSure": "Are you sure?",
"imagePrompt": "Image Prompt"
"imagePrompt": "Image Prompt",
"clearNodes": "Are you sure you want to clear all nodes?"
},
"gallery": {
"generations": "Generations",
@@ -118,7 +119,7 @@
"pinGallery": "Pin Gallery",
"allImagesLoaded": "All Images Loaded",
"loadMore": "Load More",
"noImagesInGallery": "No Images In Gallery",
"noImagesInGallery": "No Images to Display",
"deleteImage": "Delete Image",
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored.",
@@ -342,6 +343,7 @@
"safetensorModels": "SafeTensors",
"modelAdded": "Model Added",
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"modelEntryDeleted": "Model Entry Deleted",
"cannotUseSpaces": "Cannot Use Spaces",
"addNew": "Add New",
@@ -396,8 +398,8 @@
"delete": "Delete",
"deleteModel": "Delete Model",
"deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?",
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.",
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
"formMessageDiffusersVAELocation": "VAE Location",
@@ -408,7 +410,7 @@
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location",
"v1": "v1",
@@ -419,12 +421,14 @@
"pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting",
"modelConverted": "Model Converted",
"modelConversionFailed": "Model Conversion Failed",
"sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder",
"custom": "Custom",
"customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"modelsMergeFailed": "Model Merge Failed",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
@@ -445,7 +449,8 @@
"weightedSum": "Weighted Sum",
"none": "none",
"addDifference": "Add Difference",
"pickModelType": "Pick Model Type"
"pickModelType": "Pick Model Type",
"selectModel": "Select Model"
},
"parameters": {
"general": "General",
@@ -528,7 +533,7 @@
"hidePreview": "Hide Preview",
"showPreview": "Show Preview",
"controlNetControlMode": "Control Mode",
"clipSkip": "Clip Skip",
"clipSkip": "CLIP Skip",
"aspectRatio": "Ratio"
},
"settings": {
@@ -593,7 +598,11 @@
"metadataLoadFailed": "Failed to load metadata",
"initialImageSet": "Initial Image Set",
"initialImageNotSet": "Initial Image Not Set",
"initialImageNotSetDesc": "Could not load initial image"
"initialImageNotSetDesc": "Could not load initial image",
"nodesSaved": "Nodes Saved",
"nodesLoaded": "Nodes Loaded",
"nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared"
},
"tooltip": {
"feature": {
@@ -674,5 +683,11 @@
"showProgressImages": "Show Progress Images",
"hideProgressImages": "Hide Progress Images",
"swapSizes": "Swap Sizes"
},
"nodes": {
"reloadSchema": "Reload Schema",
"saveNodes": "Save Nodes",
"loadNodes": "Load Nodes",
"clearNodes": "Clear Nodes"
}
}

View File

@@ -343,6 +343,7 @@
"safetensorModels": "SafeTensors",
"modelAdded": "Model Added",
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"modelEntryDeleted": "Model Entry Deleted",
"cannotUseSpaces": "Cannot Use Spaces",
"addNew": "Add New",
@@ -397,8 +398,8 @@
"delete": "Delete",
"deleteModel": "Delete Model",
"deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?",
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.",
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
"formMessageDiffusersVAELocation": "VAE Location",
@@ -409,7 +410,7 @@
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location",
"v1": "v1",
@@ -420,12 +421,14 @@
"pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting",
"modelConverted": "Model Converted",
"modelConversionFailed": "Model Conversion Failed",
"sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder",
"custom": "Custom",
"customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"modelsMergeFailed": "Model Merge Failed",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
@@ -446,7 +449,8 @@
"weightedSum": "Weighted Sum",
"none": "none",
"addDifference": "Add Difference",
"pickModelType": "Pick Model Type"
"pickModelType": "Pick Model Type",
"selectModel": "Select Model"
},
"parameters": {
"general": "General",
@@ -599,7 +603,6 @@
"nodesLoaded": "Nodes Loaded",
"nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared"
},
"tooltip": {
"feature": {

View File

@@ -9,9 +9,9 @@ import { theme as invokeAITheme } from 'theme/theme';
import '@fontsource-variable/inter';
import { MantineProvider } from '@mantine/core';
import { mantineTheme } from 'mantine-theme/theme';
import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css';
import { useMantineTheme } from 'mantine-theme/theme';
type ThemeLocaleProviderProps = {
children: ReactNode;
@@ -35,8 +35,10 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
document.body.dir = direction;
}, [direction]);
const mantineTheme = useMantineTheme();
return (
<MantineProvider withGlobalStyles theme={mantineTheme}>
<MantineProvider theme={mantineTheme}>
<ChakraProvider theme={theme} colorModeManager={manager}>
{children}
</ChakraProvider>

View File

@@ -1,5 +1,11 @@
import { createAction } from '@reduxjs/toolkit';
import { isLoadingChanged } from 'features/gallery/store/gallerySlice';
import {
ASSETS_CATEGORIES,
IMAGE_CATEGORIES,
INITIAL_IMAGE_LIMIT,
isLoadingChanged,
} from 'features/gallery/store/gallerySlice';
import { receivedPageOfImages } from 'services/api/thunks/image';
import { startAppListening } from '..';
export const appStarted = createAction('app/appStarted');
@@ -13,6 +19,25 @@ export const addAppStartedListener = () => {
) => {
cancelActiveListeners();
unsubscribe();
// fill up the gallery tab with images
await dispatch(
receivedPageOfImages({
categories: IMAGE_CATEGORIES,
is_intermediate: false,
offset: 0,
limit: INITIAL_IMAGE_LIMIT,
})
);
// fill up the assets tab with images
await dispatch(
receivedPageOfImages({
categories: ASSETS_CATEGORIES,
is_intermediate: false,
offset: 0,
limit: INITIAL_IMAGE_LIMIT,
})
);
dispatch(isLoadingChanged(false));
},

View File

@@ -13,7 +13,11 @@ import { RootState } from 'app/store/store';
const moduleLog = log.child({ namespace: 'controlNet' });
const predicate: AnyListenerPredicate<RootState> = (action, state) => {
const predicate: AnyListenerPredicate<RootState> = (
action,
state,
prevState
) => {
const isActionMatched =
controlNetProcessorParamsChanged.match(action) ||
controlNetModelChanged.match(action) ||
@@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate<RootState> = (action, state) => {
return false;
}
if (controlNetAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (
prevState.controlNet.controlNets[action.payload.controlNetId]
.shouldAutoConfig === true
) {
return false;
}
}
const { controlImage, processorType, shouldAutoConfig } =
state.controlNet.controlNets[action.payload.controlNetId];

View File

@@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { forEach } from 'lodash-es';
import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' });
@@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
modelsCleared += 1;
}
// TODO: handle incompatible controlnet; pending model manager support
const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
dispatch(controlNetRemoved({ controlNetId }));
modelsCleared += 1;
}
});
if (modelsCleared > 0) {
dispatch(
addToast(


@@ -11,6 +11,7 @@ import {
import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' });
@@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
// TODO: pending model manager controlnet support
const controlNets = getState().controlNet.controlNets;
forEach(controlNets, (controlNet, controlNetId) => {
const isControlNetAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === controlNet?.model?.model_name &&
m?.base_model === controlNet?.model?.base_model
);
if (isControlNetAvailable) {
return;
}
dispatch(controlNetRemoved({ controlNetId }));
});
},
});
};
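
The survival check above reads as a small standalone predicate: a configured ControlNet is kept only if some fetched entity matches both its model name and its base model. A sketch, with the entity shape assumed from the diff:

import { some } from 'lodash-es';

type ModelRef = { model_name: string; base_model: string };

// Does any fetched entity match this ControlNet's configured model?
const isControlNetAvailable = (
  entities: Record<string, ModelRef | undefined>,
  model: ModelRef | undefined
): boolean =>
  !!model &&
  some(
    entities,
    (m) => m?.model_name === model.model_name && m?.base_model === model.base_model
  );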


@@ -1,6 +1,7 @@
import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { modelsApi } from 'services/api/endpoints/models';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
@@ -23,6 +24,13 @@ export const addSocketConnectedEventListener = () => {
// pass along the socket event as an application action
dispatch(appSocketConnected(action.payload));
// update all server state
dispatch(modelsApi.endpoints.getMainModels.initiate());
dispatch(modelsApi.endpoints.getControlNetModels.initiate());
dispatch(modelsApi.endpoints.getLoRAModels.initiate());
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
dispatch(modelsApi.endpoints.getVaeModels.initiate());
},
});
};
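
Dispatching an endpoint's initiate() thunk is RTK Query's manual-prefetch mechanism: it starts the request, fills the cache, and returns a cache subscription. A usage sketch against the endpoints above (AppDispatch is an assumed store type name):

// Prefetch model lists on connect; the returned function releases the cache entries.
const prefetchModels = (dispatch: AppDispatch) => {
  const subscriptions = [
    dispatch(modelsApi.endpoints.getMainModels.initiate()),
    dispatch(modelsApi.endpoints.getControlNetModels.initiate()),
  ];
  return () => subscriptions.forEach((s) => s.unsubscribe());
};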


@@ -100,10 +100,11 @@ export const store = configureStore({
// manually type state, cannot type the arg
// const typedState = state as ReturnType<typeof rootReducer>;
if (action.type.startsWith('api/')) {
// don't log api actions, with manual cache updates they are extremely noisy
return false;
}
// TODO: doing this breaks the rtk query devtools, commenting out for now
// if (action.type.startsWith('api/')) {
// // don't log api actions, with manual cache updates they are extremely noisy
// return false;
// }
if (actionsDenylist.includes(action.type)) {
// don't log other noisy actions


@@ -1,5 +1,5 @@
import {
CONTROLNET_MODELS,
// CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap';
@@ -128,7 +128,7 @@ export type AppConfig = {
canRestoreDeletedImagesFromBin: boolean;
sd: {
defaultModel?: string;
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
disabledControlNetModels: string[];
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
iterations: {
initial: number;


@@ -170,12 +170,14 @@ const IAIDndImage = (props: IAIDndImageProps) => {
</>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
{imageDTO && (
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}


@@ -1,11 +1,13 @@
import { Flex } from '@chakra-ui/react';
import { Flex, useColorMode } from '@chakra-ui/react';
import { ReactElement } from 'react';
import { mode } from 'theme/util/mode';
export function IAIFormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
const { colorMode } = useColorMode();
return (
<Flex
sx={{
@@ -14,7 +16,7 @@ export function IAIFormItemWrapper({
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
bg: mode('base.100', 'base.900')(colorMode),
}}
>
{children}
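
The mode() helper used here (and throughout the components below) follows the standard Chakra convention: pick a light or a dark value based on the current color mode. A minimal sketch of theme/util/mode under that assumption:

import { ColorMode } from '@chakra-ui/react';

// mode(lightValue, darkValue) returns a function of the current color mode.
export const mode =
  <T,>(light: T, dark: T) =>
  (colorMode: ColorMode): T =>
    colorMode === 'light' ? light : dark;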


@@ -0,0 +1,44 @@
import { useColorMode } from '@chakra-ui/react';
import { TextInput, TextInputProps } from '@mantine/core';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { mode } from 'theme/util/mode';
type IAIMantineTextInputProps = TextInputProps;
export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
const { ...rest } = props;
const {
base50,
base100,
base200,
base300,
base800,
base700,
base900,
accent500,
accent300,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
return (
<TextInput
styles={() => ({
input: {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base50, base900)(colorMode),
borderColor: mode(base200, base800)(colorMode),
borderWidth: 2,
outline: 'none',
':focus': {
borderColor: mode(accent300, accent500)(colorMode),
},
},
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
})}
{...rest}
/>
);
}


@@ -1,38 +1,26 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
import { mode } from 'theme/util/mode';
type IAIMultiSelectProps = MultiSelectProps & {
type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
label?: string;
};
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, inputRef, ...rest } = props;
const dispatch = useAppDispatch();
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const { colorMode } = useColorMode();
searchable = true,
tooltip,
inputRef,
label,
disabled,
...rest
} = props;
const dispatch = useAppDispatch();
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
@@ -52,100 +40,25 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
[dispatch]
);
const styles = useMantineMultiSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
ref={inputRef}
disabled={disabled}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={() => ({
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
searchInput: {
':placeholder': {
color: mode(base300, base700)(colorMode),
},
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 24,
padding: 20,
button: {
color: mode(base900, base100)(colorMode),
},
},
})}
styles={styles}
{...rest}
/>
</Tooltip>
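
The large inline styles() object deleted above moves into a shared hook so the multi-select and select components can reuse a single definition. A plausible shape for useMantineMultiSelectStyles, built from the same tokens the old inline code read (abbreviated, not the real implementation):

import { useColorMode } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useCallback } from 'react';
import { mode } from 'theme/util/mode';

export const useMantineMultiSelectStyles = () => {
  const { base50, base100, base200, base800, base900 } = useChakraThemeTokens();
  const { colorMode } = useColorMode();
  return useCallback(
    () => ({
      input: {
        backgroundColor: mode(base50, base900)(colorMode),
        borderColor: mode(base200, base800)(colorMode),
        color: mode(base900, base100)(colorMode),
      },
      // ...label, dropdown, item, value, rightSection styles as in the deleted block...
    }),
    [base100, base200, base50, base800, base900, colorMode]
  );
};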


@@ -0,0 +1,95 @@
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react';
export type IAISelectDataType = {
value: string;
label: string;
tooltip?: string;
};
type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string;
label?: string;
inputRef?: RefObject<HTMLInputElement>;
};
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
const {
searchable = true,
tooltip,
inputRef,
onChange,
label,
disabled,
...rest
} = props;
const dispatch = useAppDispatch();
const [searchValue, setSearchValue] = useState('');
// we want to capture shift key presses even when an input is focused
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
setSearchValue('');
if (!onChange) {
return;
}
onChange(v);
},
[onChange]
);
const styles = useMantineSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Select
ref={inputRef}
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
disabled={disabled}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={styles}
{...rest}
/>
</Tooltip>
);
};
export default memo(IAIMantineSearchableSelect);
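
A hypothetical usage of the new component; the data items follow the IAISelectDataType shape defined above:

import { useState } from 'react';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';

const ExampleProcessorSelect = () => {
  const [value, setValue] = useState<string | null>('canny_image_processor');
  return (
    <IAIMantineSearchableSelect
      label="Processor"
      value={value}
      data={[
        { value: 'canny_image_processor', label: 'Canny' },
        { value: 'hed_image_processor', label: 'HED' },
      ]}
      onChange={setValue}
    />
  );
};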


@@ -1,10 +1,7 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react';
import { mode } from 'theme/util/mode';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { RefObject, memo } from 'react';
export type IAISelectDataType = {
value: string;
@@ -12,161 +9,30 @@ export type IAISelectDataType = {
tooltip?: string;
};
type IAISelectProps = SelectProps & {
type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
label?: string;
};
const IAIMantineSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
const dispatch = useAppDispatch();
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { tooltip, inputRef, label, disabled, ...rest } = props;
const { colorMode } = useColorMode();
const [searchValue, setSearchValue] = useState('');
// we want to capture shift key presses even when an input is focused
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
setSearchValue('');
if (!onChange) {
return;
}
onChange(v);
},
[onChange]
);
const [boxShadow] = useToken('shadows', ['dark-lg']);
const styles = useMantineSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Select
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
disabled={disabled}
ref={inputRef}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={() => ({
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base100, base900)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 32,
button: {
color: mode(base900, base100)(colorMode),
},
},
})}
styles={styles}
{...rest}
/>
</Tooltip>


@@ -28,7 +28,7 @@ import {
useState,
} from 'react';
const numberStringRegex = /^-?(0\.)?\.?$/;
export const numberStringRegex = /^-?(0\.)?\.?$/;
interface Props extends Omit<NumberInputProps, 'onChange'> {
label?: string;
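
Exporting numberStringRegex lets other inputs share the same partial-number check; the pattern appears to accept only transient states while the user is mid-keystroke, such as a bare minus sign or a trailing decimal point. For illustration:

const partial = /^-?(0\.)?\.?$/; // same pattern as numberStringRegex above
['-', '.', '0.', '-0.'].every((s) => partial.test(s)); // true
partial.test('12'); // false, complete numbers are parsed as numbers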


@@ -43,11 +43,6 @@ import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi';
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
};
export type IAIFullSliderProps = {
label?: string;
value: number;
@@ -207,7 +202,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
{...sliderFormControlProps}
>
{label && (
<FormLabel {...sliderFormLabelProps} mb={-1}>
<FormLabel sx={withInput ? { mb: -1.5 } : {}} {...sliderFormLabelProps}>
{label}
</FormLabel>
)}
@@ -233,7 +228,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@@ -244,7 +238,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@@ -263,7 +256,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@@ -278,7 +270,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@@ -291,7 +282,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
key={m}
value={m}
sx={{
...SLIDER_MARK_STYLES,
transform: 'translateX(-50%)',
}}
{...sliderMarkProps}
>


@@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { modelsApi } from '../../services/api/endpoints/models';
import { forEach } from 'lodash-es';
const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector],
@@ -52,6 +53,13 @@ const readinessSelector = createSelector(
reasonsWhyNotReady.push('Seed-Weights badly formatted.');
}
forEach(state.controlNet.controlNets, (controlNet, id) => {
if (!controlNet.model) {
isReady = false;
reasonsWhyNotReady.push(`ControlNet ${id} has no model selected.`);
}
});
// All good
return { isReady, reasonsWhyNotReady };
},


@@ -24,7 +24,7 @@ import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash-es';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
canvasCopiedToClipboard,
canvasDownloadedAsImage,
@@ -213,7 +213,7 @@ const IAICanvasToolbar = () => {
}}
>
<Box w={24}>
<IAIMantineSelect
<IAIMantineSearchableSelect
tooltip={`${t('unifiedCanvas.layer')} (Q)`}
value={layer}
data={LAYER_NAMES_DICT}


@@ -1,10 +1,9 @@
import { Box, ChakraProps, Flex, useColorMode } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { Box, Flex } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import {
ControlNetConfig,
controlNetAdded,
controlNetDuplicated,
controlNetRemoved,
controlNetToggled,
} from '../store/controlNetSlice';
@@ -12,6 +11,9 @@ import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { ChevronUpIcon } from '@chakra-ui/icons';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { useToggle } from 'react-use';
@@ -22,41 +24,41 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import { mode } from 'theme/util/mode';
const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
type ControlNetProps = {
controlNet: ControlNetConfig;
controlNetId: string;
};
const ControlNet = (props: ControlNetProps) => {
const {
controlNetId,
isEnabled,
model,
weight,
beginStepPct,
endStepPct,
controlMode,
controlImage,
processedControlImage,
processorNode,
processorType,
shouldAutoConfig,
} = props.controlNet;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const selector = createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const [isExpanded, toggleIsExpanded] = useToggle(false);
const { colorMode } = useColorMode();
const handleDelete = useCallback(() => {
dispatch(controlNetRemoved({ controlNetId }));
}, [controlNetId, dispatch]);
const handleDuplicate = useCallback(() => {
dispatch(
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet })
controlNetDuplicated({
sourceControlNetId: controlNetId,
newControlNetId: uuidv4(),
})
);
}, [dispatch, props.controlNet]);
}, [controlNetId, dispatch]);
const handleToggleIsEnabled = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
@@ -68,15 +70,18 @@ const ControlNet = (props: ControlNetProps) => {
flexDir: 'column',
gap: 2,
p: 3,
bg: mode('base.200', 'base.850')(colorMode),
borderRadius: 'base',
position: 'relative',
bg: 'base.200',
_dark: {
bg: 'base.850',
},
}}
>
<Flex sx={{ gap: 2 }}>
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAISwitch
tooltip="Toggle"
aria-label="Toggle"
tooltip={'Toggle this ControlNet'}
aria-label={'Toggle this ControlNet'}
isChecked={isEnabled}
onChange={handleToggleIsEnabled}
/>
@@ -90,7 +95,7 @@ const ControlNet = (props: ControlNetProps) => {
transitionDuration: '0.1s',
}}
>
<ParamControlNetModel controlNetId={controlNetId} model={model} />
<ParamControlNetModel controlNetId={controlNetId} />
</Box>
<IAIIconButton
size="sm"
@@ -109,21 +114,26 @@ const ControlNet = (props: ControlNetProps) => {
/>
<IAIIconButton
size="sm"
aria-label="Show All Options"
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
onClick={toggleIsExpanded}
variant="link"
icon={
<ChevronUpIcon
sx={{
boxSize: 4,
color: mode('base.700', 'base.300')(colorMode),
color: 'base.700',
transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
_dark: {
color: 'base.300',
},
}}
/>
}
/>
{!shouldAutoConfig && (
<Box
sx={{
@@ -131,85 +141,59 @@ const ControlNet = (props: ControlNetProps) => {
w: 1.5,
h: 1.5,
borderRadius: 'full',
bg: mode('error.700', 'error.200')(colorMode),
top: 4,
insetInlineEnd: 4,
bg: 'accent.700',
_dark: {
bg: 'accent.400',
},
}}
/>
)}
</Flex>
{isEnabled && (
<>
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex sx={{ gap: 4, w: 'full' }}>
<Flex
sx={{
flexDir: 'column',
gap: 3,
w: 'full',
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
</Flex>
{!isExpanded && (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 24,
w: 24,
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview
controlNet={props.controlNet}
height={24}
/>
</Flex>
)}
</Flex>
<ParamControlNetControlMode
controlNetId={controlNetId}
controlMode={controlMode}
/>
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex
sx={{
flexDir: 'column',
gap: 3,
h: 28,
w: 'full',
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight controlNetId={controlNetId} />
<ParamControlNetBeginEnd controlNetId={controlNetId} />
</Flex>
{isExpanded && (
<>
<Box mt={2}>
<ControlNetImagePreview
controlNet={props.controlNet}
height={96}
/>
</Box>
<ParamControlNetProcessorSelect
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ControlNetProcessorComponent
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ParamControlNetShouldAutoConfig
controlNetId={controlNetId}
shouldAutoConfig={shouldAutoConfig}
/>
</>
{!isExpanded && (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 28,
w: 28,
aspectRatio: '1/1',
mt: 3,
}}
>
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
</Flex>
)}
</Flex>
<Box mt={2}>
<ParamControlNetControlMode controlNetId={controlNetId} />
</Box>
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
</Flex>
{isExpanded && (
<>
<ControlNetImagePreview controlNetId={controlNetId} height="392px" />
<ParamControlNetShouldAutoConfig controlNetId={controlNetId} />
<ControlNetProcessorComponent controlNetId={controlNetId} />
</>
)}
</Flex>


@@ -5,42 +5,57 @@ import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { memo, useCallback, useMemo, useState } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/thunks/image';
import {
ControlNetConfig,
controlNetImageChanged,
controlNetSelector,
} from '../store/controlNetSlice';
const selector = createSelector(
controlNetSelector,
(controlNet) => {
const { pendingControlImages } = controlNet;
return { pendingControlImages };
},
defaultSelectorOptions
);
import { controlNetImageChanged } from '../store/controlNetSlice';
type Props = {
controlNet: ControlNetConfig;
controlNetId: string;
height: SystemStyleObject['h'];
};
const ControlNetImagePreview = (props: Props) => {
const { height } = props;
const {
controlNetId,
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
} = props.controlNet;
const { height, controlNetId } = props;
const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { pendingControlImages } = controlNet;
const {
controlImage,
processedControlImage,
processorType,
isEnabled,
} = controlNet.controlNets[controlNetId];
return {
controlImageName: controlImage,
processedControlImageName: processedControlImage,
processorType,
isEnabled,
pendingControlImages,
};
},
defaultSelectorOptions
),
[controlNetId]
);
const {
controlImageName,
processedControlImageName,
processorType,
pendingControlImages,
isEnabled,
} = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
@@ -110,13 +125,15 @@ const ControlNetImagePreview = (props: Props) => {
h: height,
alignItems: 'center',
justifyContent: 'center',
pointerEvents: isEnabled ? 'auto' : 'none',
opacity: isEnabled ? 1 : 0.5,
}}
>
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage}
isDropDisabled={shouldShowProcessedImage || !isEnabled}
onClickReset={handleResetControlImage}
postUploadAction={postUploadAction}
resetTooltip="Reset Control Image"
@@ -140,6 +157,7 @@ const ControlNetImagePreview = (props: Props) => {
droppableData={droppableData}
imageDTO={processedControlImage}
isUploadDisabled={true}
isDropDisabled={!isEnabled}
onClickReset={handleResetControlImage}
resetTooltip="Reset Control Image"
withResetIcon={Boolean(controlImage)}
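
This file establishes the pattern repeated in the parameter components below: a per-instance selector built inside useMemo and keyed on controlNetId, so each ControlNet card re-renders only when its own slice of state changes. Distilled into a sketch (imports and names as in the diffs above):

import { createSelector } from '@reduxjs/toolkit';
import { useMemo } from 'react';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';

const useControlNetConfig = (controlNetId: string) => {
  const selector = useMemo(
    () =>
      createSelector(
        stateSelector,
        ({ controlNet }) => controlNet.controlNets[controlNetId],
        defaultSelectorOptions
      ),
    [controlNetId]
  );
  return useAppSelector(selector);
};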


@@ -1,10 +1,13 @@
import { memo } from 'react';
import { RequiredControlNetProcessorNode } from '../store/types';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo, useMemo } from 'react';
import CannyProcessor from './processors/CannyProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartProcessor from './processors/LineartProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import LineartProcessor from './processors/LineartProcessor';
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
import MidasDepthProcessor from './processors/MidasDepthProcessor';
import MlsdImageProcessor from './processors/MlsdImageProcessor';
@@ -15,23 +18,45 @@ import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
export type ControlNetProcessorProps = {
controlNetId: string;
processorNode: RequiredControlNetProcessorNode;
};
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
const { controlNetId, processorNode } = props;
const { controlNetId } = props;
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, processorNode } = useAppSelector(selector);
if (processorNode.type === 'canny_image_processor') {
return (
<CannyProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
if (processorNode.type === 'hed_image_processor') {
return (
<HedProcessor controlNetId={controlNetId} processorNode={processorNode} />
<HedProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@@ -40,6 +65,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<LineartProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@@ -49,6 +75,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<ContentShuffleProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@@ -58,6 +85,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<LineartAnimeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@@ -67,6 +95,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MediapipeFaceProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@@ -76,6 +105,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MidasDepthProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@@ -85,6 +115,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MlsdImageProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@@ -94,6 +125,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<NormalBaeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@@ -103,6 +135,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<OpenposeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@@ -112,6 +145,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<PidiProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@@ -121,6 +155,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<ZoeDepthProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
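
The long if-chain preserves TypeScript's discriminated-union narrowing: each branch pairs a processorNode.type with the exact Required*Invocation props its component expects. A table-driven alternative is shorter but gives up that narrowing, which is a plausible reason the explicit chain was kept. Sketch only, not what the commit does:

import { ComponentType } from 'react';

// Table-driven dispatch; the `any` loses the per-processor prop typing.
const PROCESSOR_COMPONENTS: Record<string, ComponentType<any> | undefined> = {
  canny_image_processor: CannyProcessor,
  hed_image_processor: HedProcessor,
  lineart_image_processor: LineartProcessor,
  // ...one entry per remaining processor type...
};

const lookupProcessor = (type: string) => PROCESSOR_COMPONENTS[type] ?? null;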


@@ -1,18 +1,36 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback, useMemo } from 'react';
type Props = {
controlNetId: string;
shouldAutoConfig: boolean;
};
const ParamControlNetShouldAutoConfig = (props: Props) => {
const { controlNetId, shouldAutoConfig } = props;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const isBusy = useAppSelector(selectIsBusy);
const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlNetAutoConfigToggled({ controlNetId }));
}, [controlNetId, dispatch]);
@@ -23,7 +41,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
aria-label="Auto configure processor"
isChecked={shouldAutoConfig}
onChange={handleShouldAutoConfigChanged}
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
);
};


@@ -1,5 +1,4 @@
import {
ChakraProps,
FormControl,
FormLabel,
HStack,
@@ -10,34 +9,41 @@ import {
RangeSliderTrack,
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import {
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
fontWeight: '500',
color: 'base.400',
};
import { memo, useCallback, useMemo } from 'react';
type Props = {
controlNetId: string;
beginStepPct: number;
endStepPct: number;
mini?: boolean;
};
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => {
const { controlNetId, beginStepPct, mini = false, endStepPct } = props;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { beginStepPct, endStepPct, isEnabled } =
controlNet.controlNets[controlNetId];
return { beginStepPct, endStepPct, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector);
const handleStepPctChanged = useCallback(
(v: number[]) => {
@@ -55,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
}, [controlNetId, dispatch]);
return (
<FormControl>
<FormControl isDisabled={!isEnabled}>
<FormLabel>Begin / End Step Percentage</FormLabel>
<HStack w="100%" gap={2} alignItems="center">
<RangeSlider
@@ -66,6 +72,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
max={1}
step={0.01}
minStepsBetweenThumbs={5}
isDisabled={!isEnabled}
>
<RangeSliderTrack>
<RangeSliderFilledTrack />
@@ -76,38 +83,33 @@ const ParamControlNetBeginEnd = (props: Props) => {
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
<RangeSliderThumb index={1} />
</Tooltip>
{!mini && (
<>
<RangeSliderMark
value={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}}
>
0%
</RangeSliderMark>
<RangeSliderMark
value={0.5}
sx={{
...SLIDER_MARK_STYLES,
}}
>
50%
</RangeSliderMark>
<RangeSliderMark
value={1}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}}
>
100%
</RangeSliderMark>
</>
)}
<RangeSliderMark
value={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
}}
>
0%
</RangeSliderMark>
<RangeSliderMark
value={0.5}
sx={{
insetInlineStart: '50% !important',
transform: 'translateX(-50%)',
}}
>
50%
</RangeSliderMark>
<RangeSliderMark
value={1}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
}}
>
100%
</RangeSliderMark>
</RangeSlider>
</HStack>
</FormControl>


@@ -1,15 +1,17 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlModes,
controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetControlModeProps = {
controlNetId: string;
controlMode: string;
};
const CONTROL_MODE_DATA = [
@@ -22,8 +24,23 @@ const CONTROL_MODE_DATA = [
export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps
) {
const { controlNetId, controlMode = false } = props;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { controlMode, isEnabled } =
controlNet.controlNets[controlNetId];
return { controlMode, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { controlMode, isEnabled } = useAppSelector(selector);
const { t } = useTranslation();
@@ -36,7 +53,8 @@ export default function ParamControlNetControlMode(
return (
<IAIMantineSelect
label={t('parameters.controlNetControlMode')}
disabled={!isEnabled}
label="Control Mode"
data={CONTROL_MODE_DATA}
value={String(controlMode)}
onChange={handleControlModeChange}


@@ -29,6 +29,9 @@ const ParamControlNetFeatureToggle = () => {
label="Enable ControlNet"
isChecked={isEnabled}
onChange={handleChange}
formControlProps={{
width: '100%',
}}
/>
);
};


@@ -1,28 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { controlNetToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isEnabled: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isEnabled } = props;
const dispatch = useAppDispatch();
const handleIsEnabledChanged = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
}, [dispatch, controlNetId]);
return (
<IAISwitch
label="Enabled"
isChecked={isEnabled}
onChange={handleIsEnabledChanged}
/>
);
};
export default memo(ParamControlNetIsEnabled);


@@ -1,36 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
import IAISwitch from 'common/components/IAISwitch';
import {
controlNetToggled,
isControlNetImagePreprocessedToggled,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isControlImageProcessed: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isControlImageProcessed } = props;
const dispatch = useAppDispatch();
const handleIsControlImageProcessedToggled = useCallback(() => {
dispatch(
isControlNetImagePreprocessedToggled({
controlNetId,
})
);
}, [controlNetId, dispatch]);
return (
<IAISwitch
label="Preprocess"
isChecked={isControlImageProcessed}
onChange={handleIsControlImageProcessedToggled}
/>
);
};
export default memo(ParamControlNetIsEnabled);


@@ -1,59 +1,118 @@
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import {
CONTROLNET_MODELS,
ControlNetModelName,
} from 'features/controlNet/store/constants';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
type ParamControlNetModelProps = {
controlNetId: string;
model: ControlNetModelName;
};
const selector = createSelector(configSelector, (config) => {
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({
label: m.label,
value: m.type,
})).filter(
(d) =>
!config.sd.disabledControlNetModels.includes(
d.value as ControlNetModelName
)
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId } = props;
const dispatch = useAppDispatch();
const isBusy = useAppSelector(selectIsBusy);
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ generation, controlNet }) => {
const { model } = generation;
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled;
return { mainModel: model, controlNetModel, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
return controlNetModels;
});
const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector);
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props;
const controlNetModels = useAppSelector(selector);
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const { data: controlNetModels } = useGetControlNetModelsQuery();
const data = useMemo(() => {
if (!controlNetModels) {
return [];
}
const data: SelectItem[] = [];
forEach(controlNetModels.entities, (model, id) => {
if (!model) {
return;
}
const disabled = model?.base_model !== mainModel?.base_model;
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
disabled,
tooltip: disabled
? `Incompatible base model: ${model.base_model}`
: undefined,
});
});
return data;
}, [controlNetModels, mainModel?.base_model]);
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
);
const handleModelChanged = useCallback(
(val: string | null) => {
// TODO: do not cast
const model = val as ControlNetModelName;
dispatch(controlNetModelChanged({ controlNetId, model }));
(v: string | null) => {
if (!v) {
return;
}
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch(
controlNetModelChanged({ controlNetId, model: newControlNetModel })
);
},
[controlNetId, dispatch]
);
return (
<IAIMantineSelect
data={controlNetModels}
value={model}
<IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip}
data={data}
error={
!selectedModel || mainModel?.base_model !== selectedModel.base_model
}
placeholder="Select a model"
value={selectedModel?.id ?? null}
onChange={handleModelChanged}
disabled={!isReady}
tooltip={model}
disabled={isBusy || !isEnabled}
tooltip={selectedModel?.description}
/>
);
};
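
modelIdToControlNetModelParam is not shown in this diff; given the `${base_model}/controlnet/${model_name}` cache key used above, a plausible sketch is (the real util presumably also validates the result, e.g. with zod):

// Hypothetical: parse a model id of the form "sd-1/controlnet/canny".
export const modelIdToControlNetModelParam = (
  modelId: string
): { base_model: string; model_name: string } | undefined => {
  const [base_model, model_type, ...rest] = modelId.split('/');
  const model_name = rest.join('/');
  if (!base_model || model_type !== 'controlnet' || !model_name) {
    return undefined; // malformed or non-controlnet id
  }
  return { base_model, model_name };
};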


@@ -1,24 +1,22 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, {
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
} from 'common/components/IAIMantineSearchableSelect';
import { configSelector } from 'features/system/store/configSelectors';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import {
ControlNetProcessorNode,
ControlNetProcessorType,
} from '../../store/types';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { createSelector } from '@reduxjs/toolkit';
import { configSelector } from 'features/system/store/configSelectors';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { ControlNetProcessorType } from '../../store/types';
import { FormControl, FormLabel } from '@chakra-ui/react';
type ParamControlNetProcessorSelectProps = {
controlNetId: string;
processorNode: ControlNetProcessorNode;
};
const selector = createSelector(
@@ -54,10 +52,24 @@ const selector = createSelector(
const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
) => {
const { controlNetId, processorNode } = props;
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const { controlNetId } = props;
const processorNodeSelector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const isBusy = useAppSelector(selectIsBusy);
const controlNetProcessors = useAppSelector(selector);
const { isEnabled, processorNode } = useAppSelector(processorNodeSelector);
const handleProcessorTypeChanged = useCallback(
(v: string | null) => {
@@ -72,12 +84,12 @@ const ParamControlNetProcessorSelect = (
);
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
label="Processor"
value={processorNode.type ?? 'canny_image_processor'}
data={controlNetProcessors}
onChange={handleProcessorTypeChanged}
disabled={!isReady}
disabled={isBusy || !isEnabled}
/>
);
};


@@ -1,18 +1,32 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
type ParamControlNetWeightProps = {
controlNetId: string;
weight: number;
mini?: boolean;
};
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
const { controlNetId, weight, mini = false } = props;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { weight, isEnabled } = controlNet.controlNets[controlNetId];
return { weight, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { weight, isEnabled } = useAppSelector(selector);
const handleWeightChanged = useCallback(
(weight: number) => {
dispatch(controlNetWeightChanged({ controlNetId, weight }));
@@ -22,15 +36,15 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
return (
<IAISlider
isDisabled={!isEnabled}
label={'Weight'}
sliderFormLabelProps={{ pb: 2 }}
value={weight}
onChange={handleWeightChanged}
min={-1}
max={1}
min={0}
max={2}
step={0.01}
withSliderMarks={!mini}
sliderMarks={[-1, 0, 1]}
withSliderMarks
sliderMarks={[0, 1, 2]}
/>
);
};


@@ -1,22 +1,25 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation;
type CannyProcessorProps = {
controlNetId: string;
processorNode: RequiredCannyImageProcessorInvocation;
isEnabled: boolean;
};
const CannyProcessor = (props: CannyProcessorProps) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { low_threshold, high_threshold } = processorNode;
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const processorChanged = useProcessorNodeChanged();
const handleLowThresholdChanged = useCallback(
@@ -48,7 +51,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
return (
<ProcessorWrapper>
<IAISlider
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
label="Low Threshold"
value={low_threshold}
onChange={handleLowThresholdChanged}
@@ -60,7 +63,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
withSliderMarks
/>
<IAISlider
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
label="High Threshold"
value={high_threshold}
onChange={handleHighThresholdChanged}
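
selectIsBusy itself is not part of this diff; it presumably derives a single busy flag from the system slice. A sketch under that assumption (the exact field name is a guess):

import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';

export const selectIsBusy = createSelector(
  stateSelector,
  ({ system }) => system.isProcessing
);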


@@ -4,20 +4,23 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsBusy } from 'features/system/store/systemSelectors';
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor
.default as RequiredContentShuffleImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredContentShuffleImageProcessorInvocation;
isEnabled: boolean;
};
const ContentShuffleProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, w, h, f } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@@ -96,7 +99,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@@ -108,7 +111,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="W"
@@ -120,7 +123,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="H"
@@ -132,7 +135,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="F"
@@ -144,7 +147,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);


@@ -1,25 +1,29 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor
.default as RequiredHedImageProcessorInvocation;
type HedProcessorProps = {
controlNetId: string;
processorNode: RequiredHedImageProcessorInvocation;
isEnabled: boolean;
};
const HedPreprocessor = (props: HedProcessorProps) => {
const {
controlNetId,
processorNode: { detect_resolution, image_resolution, scribble },
isEnabled,
} = props;
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
@@ -67,7 +71,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@@ -79,13 +83,13 @@ const HedPreprocessor = (props: HedProcessorProps) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISwitch
label="Scribble"
isChecked={scribble}
onChange={handleScribbleChanged}
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);


@@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor
.default as RequiredLineartAnimeImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredLineartAnimeImageProcessorInvocation;
isEnabled: boolean;
};
const LineartAnimeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
-const isReady = useIsReadyToInvoke();
+const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@@ -57,7 +60,7 @@ const LineartAnimeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@@ -69,7 +72,7 @@ const LineartAnimeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@@ -1,24 +1,27 @@
+import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
+import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
-import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
-const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
+const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor
+.default as RequiredLineartImageProcessorInvocation;
type LineartProcessorProps = {
controlNetId: string;
processorNode: RequiredLineartImageProcessorInvocation;
+isEnabled: boolean;
};
const LineartProcessor = (props: LineartProcessorProps) => {
-const { controlNetId, processorNode } = props;
+const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, coarse } = processorNode;
const processorChanged = useProcessorNodeChanged();
-const isReady = useIsReadyToInvoke();
+const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@@ -65,7 +68,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
max={4096}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@@ -77,13 +80,13 @@ const LineartProcessor = (props: LineartProcessorProps) => {
max={4096}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISwitch
label="Coarse"
isChecked={coarse}
onChange={handleCoarseChanged}
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@@ -1,23 +1,26 @@
+import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
+import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
-import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
-const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
+const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor
+.default as RequiredMediapipeFaceProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredMediapipeFaceProcessorInvocation;
+isEnabled: boolean;
};
const MediapipeFaceProcessor = (props: Props) => {
-const { controlNetId, processorNode } = props;
+const { controlNetId, processorNode, isEnabled } = props;
const { max_faces, min_confidence } = processorNode;
const processorChanged = useProcessorNodeChanged();
-const isReady = useIsReadyToInvoke();
+const isBusy = useAppSelector(selectIsBusy);
const handleMaxFacesChanged = useCallback(
(v: number) => {
@@ -53,7 +56,7 @@ const MediapipeFaceProcessor = (props: Props) => {
max={20}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Min Confidence"
@@ -66,7 +69,7 @@ const MediapipeFaceProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@@ -1,23 +1,26 @@
+import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
+import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
-import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
-const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
+const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor
+.default as RequiredMidasDepthImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredMidasDepthImageProcessorInvocation;
+isEnabled: boolean;
};
const MidasDepthProcessor = (props: Props) => {
-const { controlNetId, processorNode } = props;
+const { controlNetId, processorNode, isEnabled } = props;
const { a_mult, bg_th } = processorNode;
const processorChanged = useProcessorNodeChanged();
-const isReady = useIsReadyToInvoke();
+const isBusy = useAppSelector(selectIsBusy);
const handleAMultChanged = useCallback(
(v: number) => {
@@ -54,7 +57,7 @@ const MidasDepthProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="bg_th"
@@ -67,7 +70,7 @@ const MidasDepthProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@@ -1,23 +1,26 @@
+import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
+import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
-import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
-const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
+const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor
+.default as RequiredMlsdImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredMlsdImageProcessorInvocation;
+isEnabled: boolean;
};
const MlsdImageProcessor = (props: Props) => {
-const { controlNetId, processorNode } = props;
+const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
const processorChanged = useProcessorNodeChanged();
-const isReady = useIsReadyToInvoke();
+const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@@ -79,7 +82,7 @@ const MlsdImageProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@@ -91,7 +94,7 @@ const MlsdImageProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="W"
@@ -104,7 +107,7 @@ const MlsdImageProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="H"
@@ -117,7 +120,7 @@ const MlsdImageProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@@ -1,23 +1,26 @@
+import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
+import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
-import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
-const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
+const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor
+.default as RequiredNormalbaeImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredNormalbaeImageProcessorInvocation;
+isEnabled: boolean;
};
const NormalBaeProcessor = (props: Props) => {
-const { controlNetId, processorNode } = props;
+const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
-const isReady = useIsReadyToInvoke();
+const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@@ -57,7 +60,7 @@ const NormalBaeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@@ -69,7 +72,7 @@ const NormalBaeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@@ -1,24 +1,27 @@
+import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
+import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
-import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
-const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
+const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor
+.default as RequiredOpenposeImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredOpenposeImageProcessorInvocation;
+isEnabled: boolean;
};
const OpenposeProcessor = (props: Props) => {
-const { controlNetId, processorNode } = props;
+const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
const processorChanged = useProcessorNodeChanged();
-const isReady = useIsReadyToInvoke();
+const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@@ -65,7 +68,7 @@ const OpenposeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@@ -77,13 +80,13 @@ const OpenposeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISwitch
label="Hand and Face"
isChecked={hand_and_face}
onChange={handleHandAndFaceChanged}
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@@ -1,24 +1,27 @@
+import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
+import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
-import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
-const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
+const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor
+.default as RequiredPidiImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredPidiImageProcessorInvocation;
+isEnabled: boolean;
};
const PidiProcessor = (props: Props) => {
-const { controlNetId, processorNode } = props;
+const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
const processorChanged = useProcessorNodeChanged();
-const isReady = useIsReadyToInvoke();
+const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@@ -72,7 +75,7 @@ const PidiProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@@ -84,7 +87,7 @@ const PidiProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
<IAISwitch
label="Scribble"
@@ -95,7 +98,7 @@ const PidiProcessor = (props: Props) => {
label="Safe"
isChecked={safe}
onChange={handleSafeChanged}
-isDisabled={!isReady}
+isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);
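The switch handlers themselves (Scribble, Safe, Hand and Face, Coarse) are elided from these hunks; judging from the ChangeEvent import and the useProcessorNodeChanged hook, they pass the checkbox state through a partial-update callback. A plausible reconstruction, with the updater's signature assumed rather than taken from the diff:

import { ChangeEvent, useCallback } from 'react';

// Assumed shape of the updater returned by useProcessorNodeChanged.
type ProcessorChanged = (
  controlNetId: string,
  changes: Record<string, unknown>
) => void;

// Hypothetical extraction of the handler pattern used for each IAISwitch.
export const useSwitchHandler = (
  controlNetId: string,
  field: string,
  processorChanged: ProcessorChanged
) =>
  useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
      // forward the checkbox state as a partial update to the processor node
      processorChanged(controlNetId, { [field]: e.target.checked });
    },
    [controlNetId, field, processorChanged]
  );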

View File

@@ -4,6 +4,7 @@ import { memo } from 'react';
type Props = {
controlNetId: string;
processorNode: RequiredZoeDepthImageProcessorInvocation;
+isEnabled: boolean;
};
const ZoeDepthProcessor = (props: Props) => {

View File

@@ -173,91 +173,17 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
},
};
-type ControlNetModelsDict = Record<string, ControlNetModel>;
-type ControlNetModel = {
-type: string;
-label: string;
-description?: string;
-defaultProcessor?: ControlNetProcessorType;
-};
+export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
+[key: string]: ControlNetProcessorType;
+} = {
+canny: 'canny_image_processor',
+mlsd: 'mlsd_image_processor',
+depth: 'midas_depth_image_processor',
+bae: 'normalbae_image_processor',
+lineart: 'lineart_image_processor',
+lineart_anime: 'lineart_anime_image_processor',
+softedge: 'hed_image_processor',
+shuffle: 'content_shuffle_image_processor',
+openpose: 'openpose_image_processor',
+mediapipe: 'mediapipe_face_processor',
+};
-export const CONTROLNET_MODELS: ControlNetModelsDict = {
-'lllyasviel/control_v11p_sd15_canny': {
-type: 'lllyasviel/control_v11p_sd15_canny',
-label: 'Canny',
-defaultProcessor: 'canny_image_processor',
-},
-'lllyasviel/control_v11p_sd15_inpaint': {
-type: 'lllyasviel/control_v11p_sd15_inpaint',
-label: 'Inpaint',
-defaultProcessor: 'none',
-},
-'lllyasviel/control_v11p_sd15_mlsd': {
-type: 'lllyasviel/control_v11p_sd15_mlsd',
-label: 'M-LSD',
-defaultProcessor: 'mlsd_image_processor',
-},
-'lllyasviel/control_v11f1p_sd15_depth': {
-type: 'lllyasviel/control_v11f1p_sd15_depth',
-label: 'Depth',
-defaultProcessor: 'midas_depth_image_processor',
-},
-'lllyasviel/control_v11p_sd15_normalbae': {
-type: 'lllyasviel/control_v11p_sd15_normalbae',
-label: 'Normal Map (BAE)',
-defaultProcessor: 'normalbae_image_processor',
-},
-'lllyasviel/control_v11p_sd15_seg': {
-type: 'lllyasviel/control_v11p_sd15_seg',
-label: 'Segmentation',
-defaultProcessor: 'none',
-},
-'lllyasviel/control_v11p_sd15_lineart': {
-type: 'lllyasviel/control_v11p_sd15_lineart',
-label: 'Lineart',
-defaultProcessor: 'lineart_image_processor',
-},
-'lllyasviel/control_v11p_sd15s2_lineart_anime': {
-type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
-label: 'Lineart Anime',
-defaultProcessor: 'lineart_anime_image_processor',
-},
-'lllyasviel/control_v11p_sd15_scribble': {
-type: 'lllyasviel/control_v11p_sd15_scribble',
-label: 'Scribble',
-defaultProcessor: 'none',
-},
-'lllyasviel/control_v11p_sd15_softedge': {
-type: 'lllyasviel/control_v11p_sd15_softedge',
-label: 'Soft Edge',
-defaultProcessor: 'hed_image_processor',
-},
-'lllyasviel/control_v11e_sd15_shuffle': {
-type: 'lllyasviel/control_v11e_sd15_shuffle',
-label: 'Content Shuffle',
-defaultProcessor: 'content_shuffle_image_processor',
-},
-'lllyasviel/control_v11p_sd15_openpose': {
-type: 'lllyasviel/control_v11p_sd15_openpose',
-label: 'Openpose',
-defaultProcessor: 'openpose_image_processor',
-},
-'lllyasviel/control_v11f1e_sd15_tile': {
-type: 'lllyasviel/control_v11f1e_sd15_tile',
-label: 'Tile (experimental)',
-defaultProcessor: 'none',
-},
-'lllyasviel/control_v11e_sd15_ip2p': {
-type: 'lllyasviel/control_v11e_sd15_ip2p',
-label: 'Pix2Pix (experimental)',
-defaultProcessor: 'none',
-},
-'CrucibleAI/ControlNetMediaPipeFace': {
-type: 'CrucibleAI/ControlNetMediaPipeFace',
-label: 'Mediapipe Face',
-defaultProcessor: 'mediapipe_face_processor',
-},
-};
-export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;
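This is the pivot of the refactor: the hard-coded CONTROLNET_MODELS dictionary (keyed by Hugging Face repo IDs) is deleted, and the default processor is now resolved by substring-matching the model's name against CONTROLNET_MODEL_DEFAULT_PROCESSORS. The slice below writes the matching loop inline, twice; a sketch of it as a standalone helper (the extraction is hypothetical, the logic is taken from the diff):

import { CONTROLNET_MODEL_DEFAULT_PROCESSORS } from 'features/controlNet/store/constants';
import { ControlNetProcessorType } from 'features/controlNet/store/types';

// First substring that occurs in the model name wins; returns undefined for
// models with no auto-configurable processor (inpaint, seg, tile, ip2p, ...).
export const resolveDefaultProcessor = (
  modelName: string
): ControlNetProcessorType | undefined => {
  for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
    if (modelName.includes(modelSubstring)) {
      return CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
    }
  }
  return undefined;
};

// e.g. resolveDefaultProcessor('control_v11p_sd15_canny') === 'canny_image_processor'
// Note: first match wins, so overlapping keys are order-sensitive — a name
// containing 'lineart_anime' also contains 'lineart' and matches 'lineart' first.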

View File

@@ -1,22 +1,20 @@
-import { PayloadAction } from '@reduxjs/toolkit';
-import { createSlice } from '@reduxjs/toolkit';
+import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
-import { ImageDTO } from 'services/api/types';
+import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
+import { cloneDeep, forEach } from 'lodash-es';
+import { imageDeleted } from 'services/api/thunks/image';
+import { isAnySessionRejected } from 'services/api/thunks/session';
+import { appSocketInvocationError } from 'services/events/actions';
+import { controlNetImageProcessed } from './actions';
+import {
+CONTROLNET_MODEL_DEFAULT_PROCESSORS,
+CONTROLNET_PROCESSORS,
+} from './constants';
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';
-import {
-CONTROLNET_MODELS,
-CONTROLNET_PROCESSORS,
-ControlNetModelName,
-} from './constants';
-import { controlNetImageProcessed } from './actions';
-import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
-import { forEach } from 'lodash-es';
-import { isAnySessionRejected } from 'services/api/thunks/session';
-import { appSocketInvocationError } from 'services/events/actions';
export type ControlModes =
| 'balanced'
@@ -26,7 +24,7 @@ export type ControlModes =
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
-model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
+model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
@@ -42,7 +40,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
export type ControlNetConfig = {
controlNetId: string;
isEnabled: boolean;
-model: ControlNetModelName;
+model: ControlNetModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
@@ -86,6 +84,19 @@ export const controlNetSlice = createSlice({
controlNetId,
};
},
+controlNetDuplicated: (
+state,
+action: PayloadAction<{
+sourceControlNetId: string;
+newControlNetId: string;
+}>
+) => {
+const { sourceControlNetId, newControlNetId } = action.payload;
+const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]);
+newControlnet.controlNetId = newControlNetId;
+state.controlNets[newControlNetId] = newControlnet;
+},
controlNetAddedFromImage: (
state,
action: PayloadAction<{ controlNetId: string; controlImage: string }>
@@ -147,7 +158,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
-model: ControlNetModelName;
+model: ControlNetModelParam;
}>
) => {
const { controlNetId, model } = action.payload;
@@ -155,7 +166,15 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) {
-const processorType = CONTROLNET_MODELS[model].defaultProcessor;
+let processorType: ControlNetProcessorType | undefined = undefined;
+for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
+if (model.model_name.includes(modelSubstring)) {
+processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
+break;
+}
+}
+if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@@ -241,9 +260,19 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) {
// manage the processor for the user
-const processorType =
-CONTROLNET_MODELS[state.controlNets[controlNetId].model]
-.defaultProcessor;
+let processorType: ControlNetProcessorType | undefined = undefined;
+for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
+if (
+state.controlNets[controlNetId].model?.model_name.includes(
+modelSubstring
+)
+) {
+processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
+break;
+}
+}
+if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@@ -272,7 +301,8 @@ export const controlNetSlice = createSlice({
});
builder.addCase(imageDeleted.pending, (state, action) => {
-// Preemptively remove the image from the gallery
+// Preemptively remove the image from all controlnets
+// TODO: doesn't the imageusage stuff do this for us?
const { image_name } = action.meta.arg;
forEach(state.controlNets, (c) => {
if (c.controlImage === image_name) {
@@ -285,21 +315,6 @@ export const controlNetSlice = createSlice({
});
});
-// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
-//   const { image_name, image_url, thumbnail_url } = action.payload;
-//   forEach(state.controlNets, (c) => {
-//     if (c.controlImage?.image_name === image_name) {
-//       c.controlImage.image_url = image_url;
-//       c.controlImage.thumbnail_url = thumbnail_url;
-//     }
-//     if (c.processedControlImage?.image_name === image_name) {
-//       c.processedControlImage.image_url = image_url;
-//       c.processedControlImage.thumbnail_url = thumbnail_url;
-//     }
-//   });
-// });
builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = [];
});
@@ -313,6 +328,7 @@ export const controlNetSlice = createSlice({
export const {
isControlNetEnabledToggled,
controlNetAdded,
+controlNetDuplicated,
controlNetAddedFromImage,
controlNetRemoved,
controlNetImageChanged,
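The new controlNetDuplicated reducer deep-clones an existing config under a fresh ID; cloneDeep matters because processorNode is nested state that must not be shared between the source and the copy. A sketch of how a component might dispatch it — the uuid source, the useAppDispatch hook, and the import paths are assumptions, not shown in this diff:

import { v4 as uuidv4 } from 'uuid';
import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetDuplicated } from 'features/controlNet/store/controlNetSlice';

// Hypothetical hook returning a click handler that duplicates a ControlNet.
export const useDuplicateControlNet = (sourceControlNetId: string) => {
  const dispatch = useAppDispatch();
  return () =>
    dispatch(
      controlNetDuplicated({
        sourceControlNetId,
        newControlNetId: uuidv4(), // any unique ID works; uuid is an assumption
      })
    );
};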

View File

@@ -9,7 +9,7 @@ import {
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
-import IAIMantineSelect from 'common/components/IAIMantineSelect';
+import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es';
@@ -106,7 +106,7 @@ const ParamEmbeddingPopover = (props: Props) => {
</Text>
</Flex>
) : (
-<IAIMantineSelect
+<IAIMantineSearchableSelect
inputRef={inputRef}
autoFocus
placeholder={'Add Embedding'}
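Swapping IAIMantineSelect for IAIMantineSearchableSelect lets long embedding lists be filtered by typing; the props shown in the diff (inputRef, autoFocus, placeholder) are unchanged. Assuming the wrapper is essentially a themed Mantine Select with searching enabled, its core looks roughly like this sketch — not the actual wrapper implementation:

import { Select, SelectProps } from '@mantine/core';

// Sketch only — the real IAIMantineSearchableSelect adds the app's styling
// and tooltip item component; the `nothingFound` text here is an assumption.
export const SearchableSelectSketch = (props: SelectProps) => (
  <Select searchable nothingFound="No matching embeddings" {...props} />
);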

Some files were not shown because too many files have changed in this diff.