Compare commits

..

75 Commits

Author SHA1 Message Date
psychedelicious
4c37c48b8c fix(installer): use extra index url when updating
If we don't include this, on updating, we will always get the CPU torch/torchvision/xformers.
2023-11-15 20:43:11 +11:00
psychedelicious
0cfe2ccd9d fix: pin torch and torchvision exactly
When upgrading the app with `--extra-index-url`, torch is updated to 2.1.1, which is not an official release.

This breaks all sorts of stuff. Pin the versions exactly to avoid this.

Also pin transformers exactly while we are here.
2023-11-15 20:39:57 +11:00
Millun Atluri
b6f356f067 Change stylecheck name from "black" to "ruff" (#5090)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [X] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [X] No, because: it is trivial

      
## Have you updated all relevant documentation?
- [ ] Yes
- [X] No


## Description

After the switch to the "ruff" linter, I noticed that the stylecheck
workflow is still described as "black" in the action logs. This small PR
should fix the issue.
2023-11-15 08:29:41 +11:00
Lincoln Stein
a4f1db7c02 change stylecheck name from "black" to "ruff" 2023-11-14 11:06:10 -05:00
psychedelicious
21206bafcf chore: bump pydantic and fastapi
No breaking changes for us.

Pydantic is working on its own faster JSON parser, `jiter`, and 2.5.0 starts bringing this in. See https://github.com/pydantic/jiter

There are a number of other bugfixes and minor changes in this version of pydantic.

The FastAPI update is mostly internal but let's stay up to date.
2023-11-14 14:34:14 +11:00
Millun Atluri
a047bad391 Revert torch to use cu121 (#5091)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [X] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [x] Yes
- [ ] No


## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-14 13:47:51 +11:00
Millun Atluri
909afc266e Update 010_INSTALL_AUTOMATED.md 2023-11-13 20:28:00 -05:00
Millun Atluri
4039dd148d Update 030_INSTALL_CUDA_AND_ROCM.md 2023-11-13 20:28:00 -05:00
Millun Atluri
ea0f8b8791 Update 020_INSTALL_MANUAL.md 2023-11-13 20:28:00 -05:00
Millun Atluri
f412582d60 Update README.md to cu121 2023-11-13 20:28:00 -05:00
Millun Atluri
c5672adb6b Update 070_INSTALL_XFORMERS.md 2023-11-13 20:28:00 -05:00
Millun Atluri
0e5c3a641a Revert torch to use cu121 2023-11-13 20:28:00 -05:00
Millun Atluri
9015e72e1e Update README.md to include M3 (#5092)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [x] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [x] No, because:

      
## Have you updated all relevant documentation?
- [x] Yes
- [ ] No


## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [x] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-14 12:24:57 +11:00
Millun Atluri
6b05d27c7a Update 040_INSTALL_DOCKER.md 2023-11-14 12:22:46 +11:00
Millun Atluri
19d0673085 Update 010_INSTALL_AUTOMATED.md 2023-11-14 12:22:08 +11:00
Kieran Klaassen
048b4fe7e8 Update README.md to include M3 2023-11-13 19:11:31 -06:00
psychedelicious
e8b83fecff fix(backend): apply clip skip after lora
This handles LoRAs that attempt to modify layers skipped by CLIP Skip.
2023-11-14 11:30:15 +11:00
Lincoln Stein
8883ecb2bf Model Manager Refactor Phase 1 - SQL-based config storage (#5039)
## What type of PR is this? (check all applicable)

- [X] Refactor


## Have you discussed this change with the InvokeAI team?
- [X] Extensively
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No


## Description

As discussed with @psychedelicious and @RyanJDick, this is the first
phase of the model manager refactor. In this phase, I've added support
for storing model configuration information the `invokeai.db` SQL3
database. All the code is separate from the original model manager, so
for the time being the frontend is still using the original YAML-based
configuration, so the web app still works.

To keep things clean, I've added a new FastAPI route called
`model_records` which can add, update, retrieve and delete model
records.

The architecture is described in the first section of
`docs/contributing/MODEL_MANAGER.md`.

## QA Instructions, Screenshots, Recordings

There is a pytest for the model sql storage backend in
`tests/backend/model_manager_2/test_model_storage_sql.py`.

To populate `invokeai.db` with models from your current `models.yaml`,
do the following:

1. Stop the running server
2. Back up `invokeai.db`
3. Run `pip install -e .` to install the command used in the next step.
4. Run `invokeai-migrate-models-to-db`

This will iterate through `models.yaml` and create equivalent database
entries in the `model_config` table of `invokeai.db`. Only the models
named in the yaml file will be migrated, so anything that is autoloaded
will be ignored.

Note that in order to get the `model_records` router to be recognized by
the swagger API, I had to rebuild the frontend. Not sure why this was
necessary and would appreciate a pointer on a less radical way to do
this.

## Added/updated tests?

- [X] Yes
- [ ] No
2023-11-13 18:59:25 -05:00
Lincoln Stein
2f97f1d6d5 Merge branch 'main' into refactor/model-manager-2 2023-11-13 18:21:16 -05:00
Lincoln Stein
73d6cc824b Update Pytorch to ~2.1.0 in the installer script (#5089)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [X] Bug Fix
- [X] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [X] No, because it's required

      
## Have you updated all relevant documentation?
- [ ] Yes
- [X] No, not necessary


## Description

We use Pytorch ~2.1.0 as a dependency for InvokeAI, but the installer
still installs 2.0.1 first until Invoke AIs dependencies kick in which
causes it to get deleted anyway and replaced with 2.1.0. This is
unnecessary and probably not wanted.

Fixed the dependencies for the installation script to install Pytorch
~2.1.0 to begin with.

P.s. Is there any reason why "torchmetrics==0.11.4" is pinned? What is
the reason for that? Does that change with Pytorch 2.1? It seems to work
since we use it already. It would be nice to know the reason.

Greetings

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-13 18:20:36 -05:00
Lincoln Stein
acc0a29dca fixed ruff formatting issues 2023-11-13 18:15:17 -05:00
Lincoln Stein
38c1436f02 resolve conflicts; blackify 2023-11-13 18:12:45 -05:00
Lincoln Stein
efbdb75568 implement psychedelicious recommendations as of 13 November 2023-11-13 17:05:01 -05:00
psychedelicious
8929495aeb fix(test): remove unused assignment to value 2023-11-14 08:08:23 +11:00
psychedelicious
428f0b265f feat(api): add log stmt to update_model_record route 2023-11-14 08:06:35 +11:00
psychedelicious
7daee41ad2 fix(api): remove unused ModelsListValidator 2023-11-14 08:01:44 +11:00
psychedelicious
7cdd7b6ad7 feat(api): simplifiy list_model_records handler 2023-11-14 08:00:21 +11:00
psychedelicious
bc64cde6f9 chore: ruff lint 2023-11-14 07:57:07 +11:00
psychedelicious
4465f97cdf Merge branch 'main' into refactor/model-manager-2 2023-11-14 07:51:57 +11:00
Wubbbi
fface2cda7 Update torch to ~2.1.0 in the installer 2023-11-13 17:30:51 +01:00
blessedcoolant
7fcb8959fb chore(ui): cleanup (#5084)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [x] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Description

Bit of a cleanup. 

[chore(ui): delete unused
files](5eaea9dd64)

[feat(ui): add eslint rule
react/jsx-no-bind](3a0ec635c9)

This rule enforces no arrow functions in component props. In practice,
it means all functions passed as component props must be wrapped in
`useCallback()`.

This is a performance optimization to prevent unnecessary rerenders.

The rule is added and all violations have been fixed, whew!

[chore(ui): move useCopyImageToClipboard to
common/hooks/](f2d26a3a3c)

[chore(ui): move MM components & store to
features/](bb52861896)

Somehow they had ended up in `features/ui/tabs` which isn't right

## QA Instructions, Screenshots, Recordings

UI should still work.

It builds successfully, and I tested things out - looks good to me.
2023-11-13 13:22:41 +05:30
psychedelicious
dcf0dc4274 Merge branch 'main' into chore/ui/cleanup 2023-11-13 16:33:08 +11:00
psychedelicious
bb52861896 chore(ui): move MM components & store to features/
Somehow they had ended up in `features/ui/tabs` which isn't right
2023-11-13 16:32:03 +11:00
psychedelicious
f2d26a3a3c chore(ui): move useCopyImageToClipboard to common/hooks/ 2023-11-13 16:23:46 +11:00
psychedelicious
04d8f2dfea fix(backend): fix controlnet zip len
Do not use `strict=True` when scaling controlnet conditioning.

When using `guess_mode` (e.g. `more_control` or `more_prompt`), `down_block_res_samples` and `scales` are zipped.

These two objects are of different lengths, so using zip's strict mode raises an error.

In testing, `len(scales) === len(down_block_res_samples) + 1`.

It appears this behaviour is intentional, as the final "extra" item in `scales` is used immediately afterwards.
2023-11-13 15:45:03 +11:00
Millun Atluri
355d4cf4e2 Update Accelerate to 0.24.X (#5075)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [X] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [X] No, because: This is just housekeeping

      
## Have you updated all relevant documentation?
- [ ] Yes
- [X] No, not needed


## Description

Update Accelerate to the most recent version. No breaking changes.
Tested for 1 week in productive use now.

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-13 14:20:05 +11:00
Millun Atluri
a3a828779a Merge branch 'main' into update-accelerate 2023-11-13 14:10:53 +11:00
Lincoln Stein
8c71ff37ae Update config.py
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-11-12 19:03:39 -05:00
psychedelicious
ddb65e6034 Merge branch 'main' into chore/ui/cleanup 2023-11-13 10:53:04 +11:00
psychedelicious
3a0ec635c9 feat(ui): add eslint rule react/jsx-no-bind
This rule enforces no arrow functions in component props. In practice, it means all functions passed as component props must be wrapped in `useCallback()`.

This is a performance optimization to prevent unnecessary rerenders.

The rule is added and all violations have been fixed, whew!
2023-11-13 10:01:14 +11:00
Lincoln Stein
8afe517204 add note about discriminated union and Body() issue; blackified 2023-11-12 16:50:05 -05:00
psychedelicious
5eaea9dd64 chore(ui): delete unused files 2023-11-13 08:43:27 +11:00
Lincoln Stein
ef8dcf5fae blackify 2023-11-12 14:20:32 -05:00
Lincoln Stein
024a156114 isort 2023-11-11 13:58:36 -05:00
Lincoln Stein
7ea2a135f1 remove dangling import 2023-11-11 12:24:58 -05:00
Lincoln Stein
af2264b6eb implement workaround for FastAPI and discriminated unions in Body parameter 2023-11-11 12:22:38 -05:00
Wubbbi
41bf9ec4a3 Update Accelerate to 0.24.X 2023-11-11 09:46:23 +01:00
Lincoln Stein
2b36565e9e awkward workaround for double-Annotated in model_record route 2023-11-10 21:32:44 -05:00
Lincoln Stein
f2c3b7c317 Merge branch 'refactor/model-manager-2' of github.com:invoke-ai/InvokeAI into refactor/model-manager-2 2023-11-10 19:47:01 -05:00
Lincoln Stein
67751a01ab remove unused import 2023-11-10 19:25:05 -05:00
Lincoln Stein
cb8cdefd59 Merge branch 'main' into refactor/model-manager-2 2023-11-10 19:24:19 -05:00
Lincoln Stein
f1c846ba5c blackify 2023-11-10 19:14:29 -05:00
Lincoln Stein
3a6ba236f5 replace _class_map in ModelConfigFactory with a nested discriminated union 2023-11-10 19:14:15 -05:00
Lincoln Stein
bd56e9bc81 remove cruft code from router 2023-11-10 18:49:25 -05:00
Lincoln Stein
b55fc2935e resolve conflicts with commits done on github 2023-11-10 18:26:48 -05:00
Lincoln Stein
0544917161 multiple small fixes suggested in reviews from psychedelicious and ryan 2023-11-10 18:25:37 -05:00
Lincoln Stein
1161dfe055 Update invokeai/app/api/routers/model_records.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2023-11-10 18:24:55 -05:00
Lincoln Stein
433f347d7e Update invokeai/app/api/routers/model_records.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2023-11-10 18:22:54 -05:00
Lincoln Stein
33a412a24f Update invokeai/backend/model_manager/config.py
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-11-10 18:21:38 -05:00
Lincoln Stein
9316534d97 Update invokeai/app/services/model_records/model_records_sql.py
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-11-10 17:58:15 -05:00
Lincoln Stein
fdaa661245 revert frontend dist files to main 2023-11-10 17:57:18 -05:00
Lincoln Stein
f1c195afb7 Merge branch 'main' into refactor/model-manager-2 2023-11-10 17:54:28 -05:00
Lincoln Stein
3b363d0258 fix flake8 lint check failures 2023-11-08 16:52:46 -05:00
Lincoln Stein
36e0faea6b blackify 2023-11-08 16:47:03 -05:00
Lincoln Stein
927f8a66e6 Merge branch 'main' into refactor/model-manager-2 2023-11-08 16:46:08 -05:00
Lincoln Stein
eebc0e7315 Merge branch 'refactor/model-manager-2' of github.com:invoke-ai/InvokeAI into refactor/model-manager-2 2023-11-08 16:45:29 -05:00
Lincoln Stein
6b173cc66f multiple small stylistic changes requested by reviewers 2023-11-08 16:45:26 -05:00
Lincoln Stein
b4732a7308 Update invokeai/app/services/model_records/model_records_base.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2023-11-08 13:50:40 -05:00
Lincoln Stein
344a56327a Update invokeai/app/services/model_records/model_records_base.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2023-11-08 13:50:01 -05:00
Lincoln Stein
ce22c0fbaa sync pydantic and sql field names; merge routes 2023-11-06 18:08:57 -05:00
Lincoln Stein
55f8865524 Merge branch 'main' into refactor/model-manager-2 2023-11-05 21:45:26 -05:00
Lincoln Stein
2d051559d1 fix flake8 complaints 2023-11-05 21:45:08 -05:00
Lincoln Stein
db9cef0092 re-run isort 2023-11-04 23:50:07 -04:00
Lincoln Stein
72c34aea75 added add_model_record and get_model_record to router api 2023-11-04 23:42:44 -04:00
Lincoln Stein
edeea5237b add sql-based model config store and api 2023-11-04 23:03:26 -04:00
200 changed files with 4331 additions and 3364 deletions

View File

@@ -6,7 +6,7 @@ on:
branches: main
jobs:
black:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

View File

@@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)
_For Windows/Linux with an NVIDIA GPU:_
```terminal
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
_For Linux with an AMD GPU:_
@@ -175,7 +175,7 @@ the command `npm install -g yarn` if needed)
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
```
_For Macintoshes, either Intel or M1/M2:_
_For Macintoshes, either Intel or M1/M2/M3:_
```sh
pip install InvokeAI --use-pep517

File diff suppressed because it is too large Load Diff

View File

@@ -179,7 +179,7 @@ experimental versions later.
you will have the choice of CUDA (NVidia cards), ROCm (AMD cards),
or CPU (no graphics acceleration). On Windows, you'll have the
choice of CUDA vs CPU, and on Macs you'll be offered CPU only. When
you select CPU on M1 or M2 Macintoshes, you will get MPS-based
you select CPU on M1/M2/M3 Macintoshes, you will get MPS-based
graphics acceleration without installing additional drivers. If you
are unsure what GPU you are using, you can ask the installer to
guess.
@@ -471,7 +471,7 @@ Then type the following commands:
=== "NVIDIA System"
```bash
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu121
pip install xformers
```

View File

@@ -148,7 +148,7 @@ manager, please follow these steps:
=== "CUDA (NVidia)"
```bash
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
=== "ROCm (AMD)"
@@ -327,7 +327,7 @@ installation protocol (important!)
=== "CUDA (NVidia)"
```bash
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
=== "ROCm (AMD)"
@@ -375,7 +375,7 @@ you can do so using this unsupported recipe:
mkdir ~/invokeai
conda create -n invokeai python=3.10
conda activate invokeai
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
invokeai-configure --root ~/invokeai
invokeai --root ~/invokeai --web
```

View File

@@ -85,7 +85,7 @@ You can find which version you should download from [this link](https://docs.nvi
When installing torch and torchvision manually with `pip`, remember to provide
the argument `--extra-index-url
https://download.pytorch.org/whl/cu118` as described in the [Manual
https://download.pytorch.org/whl/cu121` as described in the [Manual
Installation Guide](020_INSTALL_MANUAL.md).
## :simple-amd: ROCm

View File

@@ -30,7 +30,7 @@ methodology for details on why running applications in such a stateless fashion
The container is configured for CUDA by default, but can be built to support AMD GPUs
by setting the `GPU_DRIVER=rocm` environment variable at Docker image build time.
Developers on Apple silicon (M1/M2): You
Developers on Apple silicon (M1/M2/M3): You
[can't access your GPU cores from Docker containers](https://github.com/pytorch/pytorch/issues/81224)
and performance is reduced compared with running it directly on macOS but for
development purposes it's fine. Once you're done with development tasks on your

View File

@@ -28,7 +28,7 @@ command line, then just be sure to activate it's virtual environment.
Then run the following three commands:
```sh
pip install xformers~=0.0.19
pip install xformers~=0.0.22
pip install triton # WON'T WORK ON WINDOWS
python -m xformers.info output
```
@@ -42,7 +42,7 @@ If all goes well, you'll see a report like the
following:
```sh
xFormers 0.0.20
xFormers 0.0.22
memory_efficient_attention.cutlassF: available
memory_efficient_attention.cutlassB: available
memory_efficient_attention.flshattF: available
@@ -59,14 +59,14 @@ swiglu.gemm_fused_operand_sum: available
swiglu.fused.p.cpp: available
is_triton_available: True
is_functorch_available: False
pytorch.version: 2.0.1+cu118
pytorch.version: 2.1.0+cu121
pytorch.cuda: available
gpu.compute_capability: 8.9
gpu.name: NVIDIA GeForce RTX 4070
build.info: available
build.cuda_version: 1108
build.python_version: 3.10.11
build.torch_version: 2.0.1+cu118
build.torch_version: 2.1.0+cu121
build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
build.env.XFORMERS_BUILD_TYPE: Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
@@ -92,33 +92,22 @@ installed from source. These instructions were written for a system
running Ubuntu 22.04, but other Linux distributions should be able to
adapt this recipe.
#### 1. Install CUDA Toolkit 11.8
#### 1. Install CUDA Toolkit 12.1
You will need the CUDA developer's toolkit in order to compile and
install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
package.** It is out of date and will cause conflicts among the NVIDIA
driver and binaries. Instead install the CUDA Toolkit package provided
by NVIDIA itself. Go to [CUDA Toolkit 11.8
Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
by NVIDIA itself. Go to [CUDA Toolkit 12.1
Downloads](https://developer.nvidia.com/cuda-12-1-0-download-archive)
and use the target selection wizard to choose your platform and Linux
distribution. Select an installer type of "runfile (local)" at the
last step.
This will provide you with a recipe for downloading and running a
install shell script that will install the toolkit and drivers. For
example, the install script recipe for Ubuntu 22.04 running on a
x86_64 system is:
install shell script that will install the toolkit and drivers.
```
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
sudo sh cuda_11.8.0_520.61.05_linux.run
```
Rather than cut-and-paste this example, We recommend that you walk
through the toolkit wizard in order to get the most up to date
installer for your system.
#### 2. Confirm/Install pyTorch 2.01 with CUDA 11.8 support
#### 2. Confirm/Install pyTorch 2.1.0 with CUDA 12.1 support
If you are using InvokeAI 3.0.2 or higher, these will already be
installed. If not, you can check whether you have the needed libraries
@@ -133,7 +122,7 @@ Then run the command:
python -c 'exec("import torch\nprint(torch.__version__)")'
```
If it prints __1.13.1+cu118__ you're good. If not, you can install the
If it prints __2.1.0+cu121__ you're good. If not, you can install the
most up to date libraries with this command:
```sh

View File

@@ -244,7 +244,7 @@ class InvokeAiInstance:
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
"urllib3~=1.26.0",
"requests~=2.28.0",
"torch~=2.0.0",
"torch~=2.1.0",
"torchmetrics==0.11.4",
"torchvision>=0.14.1",
"--force-reinstall",
@@ -460,10 +460,10 @@ def get_torch_source() -> (Union[str, None], str):
url = "https://download.pytorch.org/whl/cpu"
if device == "cuda":
url = "https://download.pytorch.org/whl/cu118"
url = "https://download.pytorch.org/whl/cu121"
optional_modules = "[xformers,onnx-cuda]"
if device == "cuda_and_dml":
url = "https://download.pytorch.org/whl/cu118"
url = "https://download.pytorch.org/whl/cu121"
optional_modules = "[xformers,onnx-directml]"
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@@ -24,6 +24,7 @@ from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@@ -85,6 +86,7 @@ class ApiDependencies:
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
@@ -111,6 +113,7 @@ class ApiDependencies:
latents=latents,
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,

View File

@@ -0,0 +1,164 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
from hashlib import sha1
from random import randbytes
from typing import List, Optional
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
UnknownModelException,
)
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelType,
)
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
class ModelsList(BaseModel):
"""Return list of configs."""
models: list[AnyModelConfig]
model_config = ConfigDict(use_enum_values=True)
@model_records_router.get(
"/",
operation_id="list_model_records",
)
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
else:
found_models.extend(record_store.search_by_attr(model_type=model_type))
return ModelsList(models=found_models)
@model_records_router.get(
"/i/{key}",
operation_id="get_model_record",
responses={
200: {"description": "Success"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
},
)
async def get_model_record(
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_records
try:
return record_store.get_model(key)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.patch(
"/i/{key}",
operation_id="update_model_record",
responses={
200: {"description": "The model was updated successfully"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
response_model=AnyModelConfig,
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
try:
model_response = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return model_response
@model_records_router.delete(
"/i/{key}",
operation_id="del_model_record",
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
)
async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
try:
record_store = ApiDependencies.invoker.services.model_records
record_store.del_model(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.post(
"/i/",
operation_id="add_model_record",
responses={
201: {"description": "The model added successfully"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
)
async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
) -> AnyModelConfig:
"""
Add a model using the configuration information appropriate for its type.
"""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
try:
record_store.add_model(config.key, config)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# now fetch it out
return record_store.get_model(config.key)

View File

@@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
import pathlib
from typing import Annotated, List, Literal, Optional, Union

View File

@@ -43,6 +43,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
board_images,
boards,
images,
model_records,
models,
session_queue,
sessions,
@@ -106,6 +107,7 @@ app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(model_records.model_records_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")

View File

@@ -112,10 +112,11 @@ class CompelInvocation(BaseInvocation):
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
):
compel = Compel(
tokenizer=tokenizer,
@@ -234,10 +235,11 @@ class SDXLPromptInvocationBase:
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
):
compel = Compel(
tokenizer=tokenizer,

View File

@@ -22,6 +22,7 @@ if TYPE_CHECKING:
from .item_storage.item_storage_base import ItemStorageABC
from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
@@ -49,6 +50,7 @@ class InvocationServices:
latents: "LatentsStorageBase"
logger: "Logger"
model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
@@ -76,6 +78,7 @@ class InvocationServices:
latents: "LatentsStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@@ -101,6 +104,7 @@ class InvocationServices:
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@@ -0,0 +1,8 @@
"""Init file for model record services."""
from .model_records_base import ( # noqa F401
DuplicateModelException,
InvalidModelException,
ModelRecordServiceBase,
UnknownModelException,
)
from .model_records_sql import ModelRecordServiceSQL # noqa F401

View File

@@ -0,0 +1,169 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for storing and retrieving model configuration records.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2.0"
class DuplicateModelException(Exception):
"""Raised on an attempt to add a model with the same key twice."""
class InvalidModelException(Exception):
"""Raised when an invalid model is detected."""
class UnknownModelException(Exception):
"""Raised on an attempt to fetch or delete a model with a nonexistent key."""
class ConfigFileVersionMismatchException(Exception):
"""Raised on an attempt to open a config with an incompatible version."""
class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@property
@abstractmethod
def version(self) -> str:
"""Return the config file/database schema version."""
pass
@abstractmethod
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
pass
@abstractmethod
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
pass
@abstractmethod
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
pass
@abstractmethod
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the configuration for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
:param key: Unique key for the model to be deleted
"""
pass
@abstractmethod
def search_by_path(
self,
path: Union[str, Path],
) -> List[AnyModelConfig]:
"""Return the model(s) having the indicated path."""
pass
@abstractmethod
def search_by_hash(
self,
hash: str,
) -> List[AnyModelConfig]:
"""Return the model(s) having the indicated original hash."""
pass
@abstractmethod
def search_by_attr(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
pass
def all_models(self) -> List[AnyModelConfig]:
"""Return all the model configs in the database."""
return self.search_by_attr()
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> AnyModelConfig:
"""
Return information about a single model using its name, base type and model type.
If there are more than one model that match, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException
"""
model_configs = self.search_by_attr(model_name=model_name, base_model=base_model, model_type=model_type)
if len(model_configs) > 1:
raise DuplicateModelException(
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
)
if len(model_configs) == 0:
raise UnknownModelException(
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
)
return model_configs[0]
def rename_model(
self,
key: str,
new_name: str,
) -> AnyModelConfig:
"""
Rename the indicated model. Just a special case of update_model().
In some implementations, renaming the model may involve changing where
it is stored on the filesystem. So this is broken out.
:param key: Model key
:param new_name: New name for model
"""
config = self.get_model(key)
config.name = new_name
return self.update_model(key, config)

View File

@@ -0,0 +1,397 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Implementation of the ModelRecordServiceBase API
Typical usage:
from invokeai.backend.model_manager import ModelConfigStoreSQL
store = ModelConfigStoreSQL(sqlite_db)
config = dict(
path='/tmp/pokemon.bin',
name='old name',
base_model='sd-1',
type='embedding',
format='embedding_file',
)
# adding - the key becomes the model's "key" field
store.add_model('key1', config)
# updating
config.name='new name'
store.update_model('key1', config)
# checking for existence
if store.exists('key1'):
print("yes")
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base)
assert new_config.key == 'key1'
# deleting
store.del_model('key1')
# searching
configs = store.search_by_path(path='/tmp/pokemon.bin')
configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01')
configs = store.search_by_attr(base_model='sd-2', model_type='main')
"""
import json
import sqlite3
from pathlib import Path
from typing import List, Optional, Union
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelConfigFactory,
ModelType,
)
from ..shared.sqlite import SqliteDatabase
from .model_records_base import (
CONFIG_FILE_VERSION,
DuplicateModelException,
ModelRecordServiceBase,
UnknownModelException,
)
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
_db: SqliteDatabase
_cursor: sqlite3.Cursor
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
with self._db.lock:
# Enable foreign keys
self._db.conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._db.conn.commit()
assert (
str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _create_tables(self) -> None:
"""Create sqlite3 tables."""
# model_config table breaks out the fields that are common to all config objects
# and puts class-specific ones in a serialized json object
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL,
type TEXT NOT NULL,
name TEXT NOT NULL,
path TEXT NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
# metadata table
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
# Add indexes for searchable fields
for stmt in [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]:
self._cursor.execute(stmt)
# Add our version to the metadata table
self._cursor.execute(
"""--sql
INSERT OR IGNORE into model_manager_metadata (
metadata_key,
metadata_value
)
VALUES (?,?);
""",
("version", CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_config (
id,
base,
type,
name,
path,
original_hash,
config
)
VALUES (?,?,?,?,?,?,?);
""",
(
key,
record.base,
record.type,
record.name,
record.path,
record.original_hash,
json_serialized,
),
)
self._db.conn.commit()
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "model_config.path" in str(e):
msg = f"A model with path '{record.path}' is already installed"
elif "model_config.name" in str(e):
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
else:
msg = f"A model with key '{key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@property
def version(self) -> str:
"""Return the version of the database schema."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata_value FROM model_manager_metadata
WHERE metadata_key=?;
""",
("version",),
)
rows = self._cursor.fetchone()
if not rows:
raise KeyError("Models database does not have metadata key 'version'")
return rows[0]
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
with self._db.lock:
try:
self._cursor.execute(
"""--sql
DELETE FROM model_config
WHERE id=?;
""",
(key,),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_config
SET base=?,
type=?,
name=?,
path=?,
config=?
WHERE id=?;
""",
(record.base, record.type, record.name, record.path, json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the ModelConfigBase instance for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE id=?;
""",
(key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]))
return model
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
:param key: Unique key for the model to be deleted
"""
count = 0
with self._db.lock:
self._cursor.execute(
"""--sql
select count(*) FROM model_config
WHERE id=?;
""",
(key,),
)
count = self._cursor.fetchone()[0]
return count > 0
def search_by_attr(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
results = []
where_clause = []
bindings = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._db.lock:
self._cursor.execute(
f"""--sql
select config FROM model_config
{where};
""",
tuple(bindings),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
"""Return models with the indicated path."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE model_path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
"""Return models with the indicated original_hash."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE original_hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results

View File

@@ -0,0 +1,323 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.
Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo',
base='sd-1',
type='main',
config='configs/stable-diffusion/v1-inference.yaml',
variant='normal',
format='checkpoint'
)
config = ModelConfigFactory.make_config(raw)
print(config.name)
Validation errors will raise an InvalidModelConfigException error.
"""
from enum import Enum
from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
class BaseModelType(str, Enum):
"""Base model type."""
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
"""Model type."""
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
class ModelVariantType(str, Enum):
"""Variant type."""
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class ModelFormat(str, Enum):
"""Storage format of model."""
Diffusers = "diffusers"
Checkpoint = "checkpoint"
Lycoris = "lycoris"
Onnx = "onnx"
Olive = "olive"
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
class SchedulerPredictionType(str, Enum):
"""Scheduler prediction type."""
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""
path: str
name: str
base: BaseModelType
type: ModelType
format: ModelFormat
key: str = Field(description="unique key for model", default="<NOKEY>")
original_hash: Optional[str] = Field(
description="original fasthash of model contents", default=None
) # this is assigned at install time and will not change
current_hash: Optional[str] = Field(
description="current fasthash of model contents", default=None
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(default=None)
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
model_config = ConfigDict(
use_enum_values=False,
validate_assignment=True,
)
def update(self, attributes: dict):
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
class _CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")
class _DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class LoRAConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
class VaeCheckpointConfig(ModelConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class VaeDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class TextualInversionConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
class _MainConfig(ModelConfigBase):
"""Model config for main models."""
vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal
ztsnr_training: bool = False
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
# Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file.
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD1Config(_MainConfig):
"""Model config for ONNX format models based on sd-1."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD2Config(_MainConfig):
"""Model config for ONNX format models based on sd-2."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config
base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
upcast_attention: bool = True
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
format: Literal[ModelFormat.InvokeAI]
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
class T2IConfig(ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers]
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Union[
_MainModelConfig,
_ONNXConfig,
_VaeConfig,
_ControlNetConfig,
LoRAConfig,
TextualInversionConfig,
IPAdapterConfig,
CLIPVisionDiffusersConfig,
T2IConfig,
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
# Union[
# _MainModelConfig,
# _ONNXConfig,
# _VaeConfig,
# _ControlNetConfig,
# LoRAConfig,
# TextualInversionConfig,
# IPAdapterConfig,
# CLIPVisionDiffusersConfig,
# T2IConfig,
# ],
# Field(discriminator="type"),
# ]
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
@classmethod
def make_config(
cls,
model_data: Union[dict, AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.
:param model_data: A raw dict corresponding the obect fields to be
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
object, which will be passed through unchanged.
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
if isinstance(model_data, ModelConfigBase):
model = model_data
elif dest_class:
model = dest_class.validate_python(model_data)
else:
model = AnyModelConfigValidator.validate_python(model_data)
if key:
model.key = key
return model

View File

@@ -0,0 +1,66 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.
Usage:
from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import hashlib
import os
from pathlib import Path
from typing import Dict, Union
from imohash import hashfile
class FastModelHash(object):
"""FastModelHash obect provides one public class method, hash()."""
@classmethod
def hash(cls, model_location: Union[str, Path]) -> str:
"""
Return hexdigest string for model located at model_location.
:param model_location: Path to the model
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location)
else:
raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str:
"""
Fasthash a single file and return its hexdigest.
:param model_location: Path to the model file
"""
# we return md5 hash of the filehash to make it shorter
# cryptographic security not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
@classmethod
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
components: Dict[str, str] = {}
for root, _dirs, files in os.walk(model_location):
for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
components.update({path: fast_hash})
# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()
for _path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8"))
return md5.hexdigest()

View File

@@ -0,0 +1,93 @@
# Copyright (c) 2023 Lincoln D. Stein
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
from hashlib import sha1
from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordServiceSQL,
)
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.util.logging import InvokeAILogger
ModelsValidator = TypeAdapter(AnyModelConfig)
class MigrateModelYamlToDb:
"""
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.2.0)
The class has one externally useful method, migrate(), which scans the
currently models.yaml file and imports all its entries into invokeai.db.
Use this way:
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
MigrateModelYamlToDb().migrate()
"""
config: InvokeAIAppConfig
logger: InvokeAILogger
def __init__(self):
self.config = InvokeAIAppConfig.get_config()
self.config.parse_args()
self.logger = InvokeAILogger.get_logger()
def get_db(self) -> ModelRecordServiceSQL:
"""Fetch the sqlite3 database for this installation."""
db = SqliteDatabase(self.config, self.logger)
return ModelRecordServiceSQL(db)
def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self.config.model_conf_path
return OmegaConf.load(yaml_path)
def migrate(self):
"""Do the migration from models.yaml to invokeai.db."""
db = self.get_db()
yaml = self.get_yaml()
for model_key, stanza in yaml.items():
if model_key == "__metadata__":
assert (
stanza["version"] == "3.0.0"
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
continue
base_type, model_type, model_name = str(model_key).split("/")
hash = FastModelHash.hash(self.config.models_path / stanza.path)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_config = ModelsValidator.validate_python(stanza)
self.logger.info(f"Adding model {model_name} with key {model_key}")
try:
db.add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
def main():
MigrateModelYamlToDb().migrate()
if __name__ == "__main__":
main()

View File

@@ -748,7 +748,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
scales = scales * conditioning_scale
down_block_res_samples = [
sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=True)
sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=False)
]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:

View File

@@ -5,6 +5,7 @@ import math
import multiprocessing as mp
import os
import re
import warnings
from collections import abc
from inspect import isfunction
from pathlib import Path
@@ -14,8 +15,10 @@ from threading import Thread
import numpy as np
import requests
import torch
from diffusers import logging as diffusers_logging
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from transformers import logging as transformers_logging
import invokeai.backend.util.logging as logger
@@ -379,3 +382,21 @@ class Chdir(object):
def __exit__(self, *args):
os.chdir(self.original)
class SilenceWarnings(object):
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
def __enter__(self):
"""Set verbosity to error."""
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, type, value, traceback):
"""Restore logger verbosity to state before context was entered."""
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")

View File

@@ -90,6 +90,14 @@ def get_extras():
pass
return extras
def get_extra_index() -> str:
# parsed_version.local for torch is the platform + version, eg 'cu121' or 'rocm5.6'
local = pkg_resources.get_distribution("torch").parsed_version.local
if local and 'cu' in local:
return "--extra-index-url https://download.pytorch.org/whl/cu121"
if local and 'rocm' in local:
return "--extra-index-url https://download.pytorch.org/whl/rocm5.6"
return ""
def main():
versions = get_versions()
@@ -122,14 +130,15 @@ def main():
branch = Prompt.ask("Enter an InvokeAI branch name")
extras = get_extras()
extra_index_url = get_extra_index()
print(f":crossed_fingers: Upgrading to [yellow]{tag or release or branch}[/yellow]")
if release:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade'
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade {extra_index_url}'
elif tag:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade'
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade {extra_index_url}'
else:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade'
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade {extra_index_url}'
print("")
print("")
if os.system(cmd) == 0:

View File

@@ -24,6 +24,7 @@ module.exports = {
root: true,
rules: {
curly: 'error',
'react/jsx-no-bind': ['error', { allowBind: true }],
'react/jsx-curly-brace-presence': [
'error',
{ props: 'never', children: 'never' },

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,280 @@
import{w as s,ia as T,v as l,a2 as I,ib as R,ae as V,ic as z,id as j,ie as D,ig as F,ih as G,ii as W,ij as K,aG as H,ik as U,il as Y}from"./index-54a1ea80.js";import{M as Z}from"./MantineProvider-17a58e64.js";var P=String.raw,E=P`
:root,
:host {
--chakra-vh: 100vh;
}
@supports (height: -webkit-fill-available) {
:root,
:host {
--chakra-vh: -webkit-fill-available;
}
}
@supports (height: -moz-fill-available) {
:root,
:host {
--chakra-vh: -moz-fill-available;
}
}
@supports (height: 100dvh) {
:root,
:host {
--chakra-vh: 100dvh;
}
}
`,B=()=>s.jsx(T,{styles:E}),J=({scope:e=""})=>s.jsx(T,{styles:P`
html {
line-height: 1.5;
-webkit-text-size-adjust: 100%;
font-family: system-ui, sans-serif;
-webkit-font-smoothing: antialiased;
text-rendering: optimizeLegibility;
-moz-osx-font-smoothing: grayscale;
touch-action: manipulation;
}
body {
position: relative;
min-height: 100%;
margin: 0;
font-feature-settings: "kern";
}
${e} :where(*, *::before, *::after) {
border-width: 0;
border-style: solid;
box-sizing: border-box;
word-wrap: break-word;
}
main {
display: block;
}
${e} hr {
border-top-width: 1px;
box-sizing: content-box;
height: 0;
overflow: visible;
}
${e} :where(pre, code, kbd,samp) {
font-family: SFMono-Regular, Menlo, Monaco, Consolas, monospace;
font-size: 1em;
}
${e} a {
background-color: transparent;
color: inherit;
text-decoration: inherit;
}
${e} abbr[title] {
border-bottom: none;
text-decoration: underline;
-webkit-text-decoration: underline dotted;
text-decoration: underline dotted;
}
${e} :where(b, strong) {
font-weight: bold;
}
${e} small {
font-size: 80%;
}
${e} :where(sub,sup) {
font-size: 75%;
line-height: 0;
position: relative;
vertical-align: baseline;
}
${e} sub {
bottom: -0.25em;
}
${e} sup {
top: -0.5em;
}
${e} img {
border-style: none;
}
${e} :where(button, input, optgroup, select, textarea) {
font-family: inherit;
font-size: 100%;
line-height: 1.15;
margin: 0;
}
${e} :where(button, input) {
overflow: visible;
}
${e} :where(button, select) {
text-transform: none;
}
${e} :where(
button::-moz-focus-inner,
[type="button"]::-moz-focus-inner,
[type="reset"]::-moz-focus-inner,
[type="submit"]::-moz-focus-inner
) {
border-style: none;
padding: 0;
}
${e} fieldset {
padding: 0.35em 0.75em 0.625em;
}
${e} legend {
box-sizing: border-box;
color: inherit;
display: table;
max-width: 100%;
padding: 0;
white-space: normal;
}
${e} progress {
vertical-align: baseline;
}
${e} textarea {
overflow: auto;
}
${e} :where([type="checkbox"], [type="radio"]) {
box-sizing: border-box;
padding: 0;
}
${e} input[type="number"]::-webkit-inner-spin-button,
${e} input[type="number"]::-webkit-outer-spin-button {
-webkit-appearance: none !important;
}
${e} input[type="number"] {
-moz-appearance: textfield;
}
${e} input[type="search"] {
-webkit-appearance: textfield;
outline-offset: -2px;
}
${e} input[type="search"]::-webkit-search-decoration {
-webkit-appearance: none !important;
}
${e} ::-webkit-file-upload-button {
-webkit-appearance: button;
font: inherit;
}
${e} details {
display: block;
}
${e} summary {
display: list-item;
}
template {
display: none;
}
[hidden] {
display: none !important;
}
${e} :where(
blockquote,
dl,
dd,
h1,
h2,
h3,
h4,
h5,
h6,
hr,
figure,
p,
pre
) {
margin: 0;
}
${e} button {
background: transparent;
padding: 0;
}
${e} fieldset {
margin: 0;
padding: 0;
}
${e} :where(ol, ul) {
margin: 0;
padding: 0;
}
${e} textarea {
resize: vertical;
}
${e} :where(button, [role="button"]) {
cursor: pointer;
}
${e} button::-moz-focus-inner {
border: 0 !important;
}
${e} table {
border-collapse: collapse;
}
${e} :where(h1, h2, h3, h4, h5, h6) {
font-size: inherit;
font-weight: inherit;
}
${e} :where(button, input, optgroup, select, textarea) {
padding: 0;
line-height: inherit;
color: inherit;
}
${e} :where(img, svg, video, canvas, audio, iframe, embed, object) {
display: block;
}
${e} :where(img, video) {
max-width: 100%;
height: auto;
}
[data-js-focus-visible]
:focus:not([data-focus-visible-added]):not(
[data-focus-visible-disabled]
) {
outline: none;
box-shadow: none;
}
${e} select::-ms-expand {
display: none;
}
${E}
`}),g={light:"chakra-ui-light",dark:"chakra-ui-dark"};function Q(e={}){const{preventTransition:o=!0}=e,n={setDataset:r=>{const t=o?n.preventTransition():void 0;document.documentElement.dataset.theme=r,document.documentElement.style.colorScheme=r,t==null||t()},setClassName(r){document.body.classList.add(r?g.dark:g.light),document.body.classList.remove(r?g.light:g.dark)},query(){return window.matchMedia("(prefers-color-scheme: dark)")},getSystemTheme(r){var t;return((t=n.query().matches)!=null?t:r==="dark")?"dark":"light"},addListener(r){const t=n.query(),i=a=>{r(a.matches?"dark":"light")};return typeof t.addListener=="function"?t.addListener(i):t.addEventListener("change",i),()=>{typeof t.removeListener=="function"?t.removeListener(i):t.removeEventListener("change",i)}},preventTransition(){const r=document.createElement("style");return r.appendChild(document.createTextNode("*{-webkit-transition:none!important;-moz-transition:none!important;-o-transition:none!important;-ms-transition:none!important;transition:none!important}")),document.head.appendChild(r),()=>{window.getComputedStyle(document.body),requestAnimationFrame(()=>{requestAnimationFrame(()=>{document.head.removeChild(r)})})}}};return n}var X="chakra-ui-color-mode";function L(e){return{ssr:!1,type:"localStorage",get(o){if(!(globalThis!=null&&globalThis.document))return o;let n;try{n=localStorage.getItem(e)||o}catch{}return n||o},set(o){try{localStorage.setItem(e,o)}catch{}}}}var ee=L(X),M=()=>{};function S(e,o){return e.type==="cookie"&&e.ssr?e.get(o):o}function O(e){const{value:o,children:n,options:{useSystemColorMode:r,initialColorMode:t,disableTransitionOnChange:i}={},colorModeManager:a=ee}=e,d=t==="dark"?"dark":"light",[u,p]=l.useState(()=>S(a,d)),[y,b]=l.useState(()=>S(a)),{getSystemTheme:w,setClassName:k,setDataset:x,addListener:$}=l.useMemo(()=>Q({preventTransition:i}),[i]),v=t==="system"&&!u?y:u,c=l.useCallback(h=>{const f=h==="system"?w():h;p(f),k(f==="dark"),x(f),a.set(f)},[a,w,k,x]);I(()=>{t==="system"&&b(w())},[]),l.useEffect(()=>{const h=a.get();if(h){c(h);return}if(t==="system"){c("system");return}c(d)},[a,d,t,c]);const C=l.useCallback(()=>{c(v==="dark"?"light":"dark")},[v,c]);l.useEffect(()=>{if(r)return $(c)},[r,$,c]);const A=l.useMemo(()=>({colorMode:o??v,toggleColorMode:o?M:C,setColorMode:o?M:c,forced:o!==void 0}),[v,C,c,o]);return s.jsx(R.Provider,{value:A,children:n})}O.displayName="ColorModeProvider";var te=["borders","breakpoints","colors","components","config","direction","fonts","fontSizes","fontWeights","letterSpacings","lineHeights","radii","shadows","sizes","space","styles","transition","zIndices"];function re(e){return V(e)?te.every(o=>Object.prototype.hasOwnProperty.call(e,o)):!1}function m(e){return typeof e=="function"}function oe(...e){return o=>e.reduce((n,r)=>r(n),o)}var ne=e=>function(...n){let r=[...n],t=n[n.length-1];return re(t)&&r.length>1?r=r.slice(0,r.length-1):t=e,oe(...r.map(i=>a=>m(i)?i(a):ae(a,i)))(t)},ie=ne(j);function ae(...e){return z({},...e,_)}function _(e,o,n,r){if((m(e)||m(o))&&Object.prototype.hasOwnProperty.call(r,n))return(...t)=>{const i=m(e)?e(...t):e,a=m(o)?o(...t):o;return z({},i,a,_)}}var q=l.createContext({getDocument(){return document},getWindow(){return window}});q.displayName="EnvironmentContext";function N(e){const{children:o,environment:n,disabled:r}=e,t=l.useRef(null),i=l.useMemo(()=>n||{getDocument:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument)!=null?u:document},getWindow:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument.defaultView)!=null?u:window}},[n]),a=!r||!n;return s.jsxs(q.Provider,{value:i,children:[o,a&&s.jsx("span",{id:"__chakra_env",hidden:!0,ref:t})]})}N.displayName="EnvironmentProvider";var se=e=>{const{children:o,colorModeManager:n,portalZIndex:r,resetScope:t,resetCSS:i=!0,theme:a={},environment:d,cssVarsRoot:u,disableEnvironment:p,disableGlobalStyle:y}=e,b=s.jsx(N,{environment:d,disabled:p,children:o});return s.jsx(D,{theme:a,cssVarsRoot:u,children:s.jsxs(O,{colorModeManager:n,options:a.config,children:[i?s.jsx(J,{scope:t}):s.jsx(B,{}),!y&&s.jsx(F,{}),r?s.jsx(G,{zIndex:r,children:b}):b]})})},le=e=>function({children:n,theme:r=e,toastOptions:t,...i}){return s.jsxs(se,{theme:r,...i,children:[s.jsx(W,{value:t==null?void 0:t.defaultOptions,children:n}),s.jsx(K,{...t})]})},de=le(j);const ue=()=>l.useMemo(()=>({colorScheme:"dark",fontFamily:"'Inter Variable', sans-serif",components:{ScrollArea:{defaultProps:{scrollbarSize:10},styles:{scrollbar:{"&:hover":{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}},thumb:{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}}}}}),[]),ce=L("@@invokeai-color-mode");function he({children:e}){const{i18n:o}=H(),n=o.dir(),r=l.useMemo(()=>ie({...U,direction:n}),[n]);l.useEffect(()=>{document.body.dir=n},[n]);const t=ue();return s.jsx(Z,{theme:t,children:s.jsx(de,{theme:r,colorModeManager:ce,toastOptions:Y,children:e})})}const ve=l.memo(he);export{ve as default};

File diff suppressed because one or more lines are too long

View File

@@ -19,7 +19,7 @@ import sdxlReducer from 'features/sdxl/store/sdxlSlice';
import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
import queueReducer from 'features/queue/store/queueSlice';
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
import modelmanagerReducer from 'features/modelManager/store/modelManagerSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import dynamicMiddlewares from 'redux-dynamic-middlewares';

View File

@@ -8,7 +8,14 @@ import {
forwardRef,
useDisclosure,
} from '@chakra-ui/react';
import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
import {
cloneElement,
memo,
ReactElement,
ReactNode,
useCallback,
useRef,
} from 'react';
import { useTranslation } from 'react-i18next';
import IAIButton from './IAIButton';
@@ -38,15 +45,15 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const cancelRef = useRef<HTMLButtonElement | null>(null);
const handleAccept = () => {
const handleAccept = useCallback(() => {
acceptCallback();
onClose();
};
}, [acceptCallback, onClose]);
const handleCancel = () => {
const handleCancel = useCallback(() => {
cancelCallback && cancelCallback();
onClose();
};
}, [cancelCallback, onClose]);
return (
<>

View File

@@ -1,43 +0,0 @@
import { Box, Flex, Icon } from '@chakra-ui/react';
import { memo } from 'react';
import { FaExclamation } from 'react-icons/fa';
const IAIErrorLoadingImageFallback = () => {
return (
<Box
sx={{
position: 'relative',
height: 'full',
width: 'full',
'::before': {
content: "''",
display: 'block',
pt: '100%',
},
}}
>
<Flex
sx={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
height: 'full',
width: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
bg: 'base.100',
color: 'base.500',
_dark: {
color: 'base.700',
bg: 'base.850',
},
}}
>
<Icon as={FaExclamation} boxSize={16} opacity={0.7} />
</Flex>
</Box>
);
};
export default memo(IAIErrorLoadingImageFallback);

View File

@@ -1,8 +0,0 @@
import { chakra } from '@chakra-ui/react';
/**
* Chakra-enabled <form />
*/
const IAIForm = chakra.form;
export default IAIForm;

View File

@@ -1,15 +0,0 @@
import { FormErrorMessage, FormErrorMessageProps } from '@chakra-ui/react';
import { ReactNode } from 'react';
type IAIFormErrorMessageProps = FormErrorMessageProps & {
children: ReactNode | string;
};
export default function IAIFormErrorMessage(props: IAIFormErrorMessageProps) {
const { children, ...rest } = props;
return (
<FormErrorMessage color="error.400" {...rest}>
{children}
</FormErrorMessage>
);
}

View File

@@ -1,15 +0,0 @@
import { FormHelperText, FormHelperTextProps } from '@chakra-ui/react';
import { ReactNode } from 'react';
type IAIFormHelperTextProps = FormHelperTextProps & {
children: ReactNode | string;
};
export default function IAIFormHelperText(props: IAIFormHelperTextProps) {
const { children, ...rest } = props;
return (
<FormHelperText margin={0} color="base.400" {...rest}>
{children}
</FormHelperText>
);
}

View File

@@ -1,25 +0,0 @@
import { Flex, useColorMode } from '@chakra-ui/react';
import { ReactElement } from 'react';
import { mode } from 'theme/util/mode';
export function IAIFormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
const { colorMode } = useColorMode();
return (
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: mode('base.100', 'base.900')(colorMode),
}}
>
{children}
</Flex>
);
}

View File

@@ -1,25 +0,0 @@
import {
Checkbox,
CheckboxProps,
FormControl,
FormControlProps,
FormLabel,
} from '@chakra-ui/react';
import { memo, ReactNode } from 'react';
type IAIFullCheckboxProps = CheckboxProps & {
label: string | ReactNode;
formControlProps?: FormControlProps;
};
const IAIFullCheckbox = (props: IAIFullCheckboxProps) => {
const { label, formControlProps, ...rest } = props;
return (
<FormControl {...formControlProps}>
<FormLabel>{label}</FormLabel>
<Checkbox colorScheme="accent" {...rest} />
</FormControl>
);
};
export default memo(IAIFullCheckbox);

View File

@@ -1,6 +1,7 @@
import { useColorMode } from '@chakra-ui/react';
import { TextInput, TextInputProps } from '@mantine/core';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useCallback } from 'react';
import { mode } from 'theme/util/mode';
type IAIMantineTextInputProps = TextInputProps;
@@ -20,26 +21,37 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
return (
<TextInput
styles={() => ({
input: {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base50, base900)(colorMode),
borderColor: mode(base200, base800)(colorMode),
borderWidth: 2,
outline: 'none',
':focus': {
borderColor: mode(accent300, accent500)(colorMode),
},
const stylesFunc = useCallback(
() => ({
input: {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base50, base900)(colorMode),
borderColor: mode(base200, base800)(colorMode),
borderWidth: 2,
outline: 'none',
':focus': {
borderColor: mode(accent300, accent500)(colorMode),
},
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
marginBottom: 4,
},
})}
{...rest}
/>
},
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal' as const,
marginBottom: 4,
},
}),
[
accent300,
accent500,
base100,
base200,
base300,
base50,
base700,
base800,
base900,
colorMode,
]
);
return <TextInput styles={stylesFunc} {...rest} />;
}

View File

@@ -98,28 +98,34 @@ const IAINumberInput = forwardRef((props: Props, ref) => {
}
}, [value, valueAsString]);
const handleOnChange = (v: string) => {
setValueAsString(v);
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
if (!v.match(numberStringRegex)) {
// Cast the value to number. Floor it if it should be an integer.
onChange(isInteger ? Math.floor(Number(v)) : Number(v));
}
};
const handleOnChange = useCallback(
(v: string) => {
setValueAsString(v);
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
if (!v.match(numberStringRegex)) {
// Cast the value to number. Floor it if it should be an integer.
onChange(isInteger ? Math.floor(Number(v)) : Number(v));
}
},
[isInteger, onChange]
);
/**
* Clicking the steppers allows the value to go outside bounds; we need to
* clamp it on blur and floor it if needed.
*/
const handleBlur = (e: FocusEvent<HTMLInputElement>) => {
const clamped = clamp(
isInteger ? Math.floor(Number(e.target.value)) : Number(e.target.value),
min,
max
);
setValueAsString(String(clamped));
onChange(clamped);
};
const handleBlur = useCallback(
(e: FocusEvent<HTMLInputElement>) => {
const clamped = clamp(
isInteger ? Math.floor(Number(e.target.value)) : Number(e.target.value),
min,
max
);
setValueAsString(String(clamped));
onChange(clamped);
},
[isInteger, max, min, onChange]
);
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {

View File

@@ -6,7 +6,7 @@ import {
Tooltip,
TooltipProps,
} from '@chakra-ui/react';
import { memo, MouseEvent } from 'react';
import { memo, MouseEvent, useCallback } from 'react';
import IAIOption from './IAIOption';
type IAISelectProps = SelectProps & {
@@ -33,15 +33,16 @@ const IAISelect = (props: IAISelectProps) => {
spaceEvenly,
...rest
} = props;
const handleClick = useCallback((e: MouseEvent<HTMLDivElement>) => {
e.stopPropagation();
e.nativeEvent.stopImmediatePropagation();
e.nativeEvent.stopPropagation();
e.nativeEvent.cancelBubble = true;
}, []);
return (
<FormControl
isDisabled={isDisabled}
onClick={(e: MouseEvent<HTMLDivElement>) => {
e.stopPropagation();
e.nativeEvent.stopImmediatePropagation();
e.nativeEvent.stopPropagation();
e.nativeEvent.cancelBubble = true;
}}
onClick={handleClick}
sx={
horizontal
? {

View File

@@ -186,6 +186,13 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
[dispatch]
);
const handleMouseEnter = useCallback(() => setShowTooltip(true), []);
const handleMouseLeave = useCallback(() => setShowTooltip(false), []);
const handleStepperClick = useCallback(
() => onChange(Number(localInputValue)),
[localInputValue, onChange]
);
return (
<FormControl
ref={ref}
@@ -219,8 +226,8 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
max={max}
step={step}
onChange={handleSliderChange}
onMouseEnter={() => setShowTooltip(true)}
onMouseLeave={() => setShowTooltip(false)}
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
focusThumbOnChange={false}
isDisabled={isDisabled}
{...rest}
@@ -332,12 +339,8 @@ const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
{...sliderNumberInputFieldProps}
/>
<NumberInputStepper {...sliderNumberInputStepperProps}>
<NumberIncrementStepper
onClick={() => onChange(Number(localInputValue))}
/>
<NumberDecrementStepper
onClick={() => onChange(Number(localInputValue))}
/>
<NumberIncrementStepper onClick={handleStepperClick} />
<NumberDecrementStepper onClick={handleStepperClick} />
</NumberInputStepper>
</NumberInput>
)}

View File

@@ -146,16 +146,15 @@ const ImageUploader = (props: ImageUploaderProps) => {
};
}, [inputRef]);
const handleKeyDown = useCallback((e: KeyboardEvent) => {
// Bail out if user hits spacebar - do not open the uploader
if (e.key === ' ') {
return;
}
}, []);
return (
<Box
{...getRootProps({ style: {} })}
onKeyDown={(e: KeyboardEvent) => {
// Bail out if user hits spacebar - do not open the uploader
if (e.key === ' ') {
return;
}
}}
>
<Box {...getRootProps({ style: {} })} onKeyDown={handleKeyDown}>
<input {...getInputProps()} />
{children}
<AnimatePresence>

View File

@@ -1,23 +0,0 @@
import { Flex, Icon } from '@chakra-ui/react';
import { memo } from 'react';
import { FaImage } from 'react-icons/fa';
const SelectImagePlaceholder = () => {
return (
<Flex
sx={{
w: 'full',
h: 'full',
// bg: 'base.800',
borderRadius: 'base',
alignItems: 'center',
justifyContent: 'center',
aspectRatio: '1/1',
}}
>
<Icon color="base.400" boxSize={32} as={FaImage}></Icon>
</Flex>
);
};
export default memo(SelectImagePlaceholder);

View File

@@ -1,24 +0,0 @@
import { useBreakpoint } from '@chakra-ui/react';
export default function useResolution():
| 'mobile'
| 'tablet'
| 'desktop'
| 'unknown' {
const breakpointValue = useBreakpoint();
const mobileResolutions = ['base', 'sm'];
const tabletResolutions = ['md', 'lg'];
const desktopResolutions = ['xl', '2xl'];
if (mobileResolutions.includes(breakpointValue)) {
return 'mobile';
}
if (tabletResolutions.includes(breakpointValue)) {
return 'tablet';
}
if (desktopResolutions.includes(breakpointValue)) {
return 'desktop';
}
return 'unknown';
}

View File

@@ -1,7 +0,0 @@
import dateFormat from 'dateformat';
/**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/
export const getTimestamp = () =>
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);

View File

@@ -1,71 +0,0 @@
// TODO: Restore variations
// Support code from v2.3 in here.
// export const stringToSeedWeights = (
// string: string
// ): InvokeAI.SeedWeights | boolean => {
// const stringPairs = string.split(',');
// const arrPairs = stringPairs.map((p) => p.split(':'));
// const pairs = arrPairs.map((p: Array<string>): InvokeAI.SeedWeightPair => {
// return { seed: Number(p[0]), weight: Number(p[1]) };
// });
// if (!validateSeedWeights(pairs)) {
// return false;
// }
// return pairs;
// };
// export const validateSeedWeights = (
// seedWeights: InvokeAI.SeedWeights | string
// ): boolean => {
// return typeof seedWeights === 'string'
// ? Boolean(stringToSeedWeights(seedWeights))
// : Boolean(
// seedWeights.length &&
// !seedWeights.some((pair: InvokeAI.SeedWeightPair) => {
// const { seed, weight } = pair;
// const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
// const isWeightValid =
// !isNaN(parseInt(weight.toString(), 10)) &&
// weight >= 0 &&
// weight <= 1;
// return !(isSeedValid && isWeightValid);
// })
// );
// };
// export const seedWeightsToString = (
// seedWeights: InvokeAI.SeedWeights
// ): string => {
// return seedWeights.reduce((acc, pair, i, arr) => {
// const { seed, weight } = pair;
// acc += `${seed}:${weight}`;
// if (i !== arr.length - 1) {
// acc += ',';
// }
// return acc;
// }, '');
// };
// export const seedWeightsToArray = (
// seedWeights: InvokeAI.SeedWeights
// ): Array<Array<number>> => {
// return seedWeights.map((pair: InvokeAI.SeedWeightPair) => [
// pair.seed,
// pair.weight,
// ]);
// };
// export const stringToSeedWeightsArray = (
// string: string
// ): Array<Array<number>> => {
// const stringPairs = string.split(',');
// const arrPairs = stringPairs.map((p) => p.split(':'));
// return arrPairs.map(
// (p: Array<string>): Array<number> => [parseInt(p[0], 10), parseFloat(p[1])]
// );
// };
export default {};

View File

@@ -5,17 +5,22 @@ import { clearCanvasHistory } from 'features/canvas/store/canvasSlice';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
import { isStagingSelector } from '../store/canvasSelectors';
import { memo } from 'react';
import { memo, useCallback } from 'react';
const ClearCanvasHistoryButtonModal = () => {
const isStaging = useAppSelector(isStagingSelector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const acceptCallback = useCallback(
() => dispatch(clearCanvasHistory()),
[dispatch]
);
return (
<IAIAlertDialog
title={t('unifiedCanvas.clearCanvasHistory')}
acceptCallback={() => dispatch(clearCanvasHistory())}
acceptCallback={acceptCallback}
acceptButtonText={t('unifiedCanvas.clearHistory')}
triggerComponent={
<IAIButton size="sm" leftIcon={<FaTrash />} isDisabled={isStaging}>

View File

@@ -20,7 +20,8 @@ import {
} from 'features/canvas/store/canvasSlice';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { isEqual } from 'lodash-es';
import { memo } from 'react';
import { ChangeEvent, memo, useCallback } from 'react';
import { RgbaColor } from 'react-colorful';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@@ -95,18 +96,35 @@ const IAICanvasMaskOptions = () => {
[isMaskEnabled]
);
const handleToggleMaskLayer = () => {
const handleToggleMaskLayer = useCallback(() => {
dispatch(setLayer(layer === 'mask' ? 'base' : 'mask'));
};
}, [dispatch, layer]);
const handleClearMask = () => dispatch(clearMask());
const handleClearMask = useCallback(() => {
dispatch(clearMask());
}, [dispatch]);
const handleToggleEnableMask = () =>
const handleToggleEnableMask = useCallback(() => {
dispatch(setIsMaskEnabled(!isMaskEnabled));
}, [dispatch, isMaskEnabled]);
const handleSaveMask = async () => {
const handleSaveMask = useCallback(async () => {
dispatch(canvasMaskSavedToGallery());
};
}, [dispatch]);
const handleChangePreserveMaskedArea = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(setShouldPreserveMaskedArea(e.target.checked));
},
[dispatch]
);
const handleChangeMaskColor = useCallback(
(newColor: RgbaColor) => {
dispatch(setMaskColor(newColor));
},
[dispatch]
);
return (
<IAIPopover
@@ -131,15 +149,10 @@ const IAICanvasMaskOptions = () => {
<IAISimpleCheckbox
label={t('unifiedCanvas.preserveMaskedArea')}
isChecked={shouldPreserveMaskedArea}
onChange={(e) =>
dispatch(setShouldPreserveMaskedArea(e.target.checked))
}
onChange={handleChangePreserveMaskedArea}
/>
<Box sx={{ paddingTop: 2, paddingBottom: 2 }}>
<IAIColorPicker
color={maskColor}
onChange={(newColor) => dispatch(setMaskColor(newColor))}
/>
<IAIColorPicker color={maskColor} onChange={handleChangeMaskColor} />
</Box>
<IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}>
Save Mask

View File

@@ -10,6 +10,7 @@ import { redo } from 'features/canvas/store/canvasSlice';
import { stateSelector } from 'app/store/store';
import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
import { useCallback } from 'react';
const canvasRedoSelector = createSelector(
[stateSelector, activeTabNameSelector],
@@ -34,9 +35,9 @@ export default function IAICanvasRedoButton() {
const { t } = useTranslation();
const handleRedo = () => {
const handleRedo = useCallback(() => {
dispatch(redo());
};
}, [dispatch]);
useHotkeys(
['meta+shift+z', 'ctrl+shift+z', 'control+y', 'meta+y'],

View File

@@ -18,7 +18,7 @@ import {
} from 'features/canvas/store/canvasSlice';
import { isEqual } from 'lodash-es';
import { ChangeEvent, memo } from 'react';
import { ChangeEvent, memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaWrench } from 'react-icons/fa';
@@ -86,8 +86,52 @@ const IAICanvasSettingsButtonPopover = () => {
[shouldSnapToGrid]
);
const handleChangeShouldSnapToGrid = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldSnapToGrid(e.target.checked));
const handleChangeShouldSnapToGrid = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldSnapToGrid(e.target.checked)),
[dispatch]
);
const handleChangeShouldShowIntermediates = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldShowIntermediates(e.target.checked)),
[dispatch]
);
const handleChangeShouldShowGrid = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldShowGrid(e.target.checked)),
[dispatch]
);
const handleChangeShouldDarkenOutsideBoundingBox = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked)),
[dispatch]
);
const handleChangeShouldAutoSave = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldAutoSave(e.target.checked)),
[dispatch]
);
const handleChangeShouldCropToBoundingBoxOnSave = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldCropToBoundingBoxOnSave(e.target.checked)),
[dispatch]
);
const handleChangeShouldRestrictStrokesToBox = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRestrictStrokesToBox(e.target.checked)),
[dispatch]
);
const handleChangeShouldShowCanvasDebugInfo = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldShowCanvasDebugInfo(e.target.checked)),
[dispatch]
);
const handleChangeShouldAntialias = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldAntialias(e.target.checked)),
[dispatch]
);
return (
<IAIPopover
@@ -104,14 +148,12 @@ const IAICanvasSettingsButtonPopover = () => {
<IAISimpleCheckbox
label={t('unifiedCanvas.showIntermediates')}
isChecked={shouldShowIntermediates}
onChange={(e) =>
dispatch(setShouldShowIntermediates(e.target.checked))
}
onChange={handleChangeShouldShowIntermediates}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.showGrid')}
isChecked={shouldShowGrid}
onChange={(e) => dispatch(setShouldShowGrid(e.target.checked))}
onChange={handleChangeShouldShowGrid}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.snapToGrid')}
@@ -121,41 +163,33 @@ const IAICanvasSettingsButtonPopover = () => {
<IAISimpleCheckbox
label={t('unifiedCanvas.darkenOutsideSelection')}
isChecked={shouldDarkenOutsideBoundingBox}
onChange={(e) =>
dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked))
}
onChange={handleChangeShouldDarkenOutsideBoundingBox}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.autoSaveToGallery')}
isChecked={shouldAutoSave}
onChange={(e) => dispatch(setShouldAutoSave(e.target.checked))}
onChange={handleChangeShouldAutoSave}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.saveBoxRegionOnly')}
isChecked={shouldCropToBoundingBoxOnSave}
onChange={(e) =>
dispatch(setShouldCropToBoundingBoxOnSave(e.target.checked))
}
onChange={handleChangeShouldCropToBoundingBoxOnSave}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.limitStrokesToBox')}
isChecked={shouldRestrictStrokesToBox}
onChange={(e) =>
dispatch(setShouldRestrictStrokesToBox(e.target.checked))
}
onChange={handleChangeShouldRestrictStrokesToBox}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.showCanvasDebugInfo')}
isChecked={shouldShowCanvasDebugInfo}
onChange={(e) =>
dispatch(setShouldShowCanvasDebugInfo(e.target.checked))
}
onChange={handleChangeShouldShowCanvasDebugInfo}
/>
<IAISimpleCheckbox
label={t('unifiedCanvas.antialiasing')}
isChecked={shouldAntialias}
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
onChange={handleChangeShouldAntialias}
/>
<ClearCanvasHistoryButtonModal />
</Flex>

View File

@@ -15,7 +15,8 @@ import {
setTool,
} from 'features/canvas/store/canvasSlice';
import { clamp, isEqual } from 'lodash-es';
import { memo } from 'react';
import { memo, useCallback } from 'react';
import { RgbaColor } from 'react-colorful';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@@ -172,11 +173,33 @@ const IAICanvasToolChooserOptions = () => {
[brushColor]
);
const handleSelectBrushTool = () => dispatch(setTool('brush'));
const handleSelectEraserTool = () => dispatch(setTool('eraser'));
const handleSelectColorPickerTool = () => dispatch(setTool('colorPicker'));
const handleFillRect = () => dispatch(addFillRect());
const handleEraseBoundingBox = () => dispatch(addEraseRect());
const handleSelectBrushTool = useCallback(() => {
dispatch(setTool('brush'));
}, [dispatch]);
const handleSelectEraserTool = useCallback(() => {
dispatch(setTool('eraser'));
}, [dispatch]);
const handleSelectColorPickerTool = useCallback(() => {
dispatch(setTool('colorPicker'));
}, [dispatch]);
const handleFillRect = useCallback(() => {
dispatch(addFillRect());
}, [dispatch]);
const handleEraseBoundingBox = useCallback(() => {
dispatch(addEraseRect());
}, [dispatch]);
const handleChangeBrushSize = useCallback(
(newSize: number) => {
dispatch(setBrushSize(newSize));
},
[dispatch]
);
const handleChangeBrushColor = useCallback(
(newColor: RgbaColor) => {
dispatch(setBrushColor(newColor));
},
[dispatch]
);
return (
<ButtonGroup isAttached>
@@ -233,7 +256,7 @@ const IAICanvasToolChooserOptions = () => {
label={t('unifiedCanvas.brushSize')}
value={brushSize}
withInput
onChange={(newSize) => dispatch(setBrushSize(newSize))}
onChange={handleChangeBrushSize}
sliderNumberInputProps={{ max: 500 }}
/>
</Flex>
@@ -247,7 +270,7 @@ const IAICanvasToolChooserOptions = () => {
<IAIColorPicker
withNumberInput={true}
color={brushColor}
onChange={(newColor) => dispatch(setBrushColor(newColor))}
onChange={handleChangeBrushColor}
/>
</Box>
</Flex>

View File

@@ -25,9 +25,9 @@ import {
LAYER_NAMES_DICT,
} from 'features/canvas/store/canvasTypes';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { useCopyImageToClipboard } from 'common/hooks/useCopyImageToClipboard';
import { isEqual } from 'lodash-es';
import { memo } from 'react';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import {
@@ -151,7 +151,9 @@ const IAICanvasToolbar = () => {
[canvasBaseLayer]
);
const handleSelectMoveTool = () => dispatch(setTool('move'));
const handleSelectMoveTool = useCallback(() => {
dispatch(setTool('move'));
}, [dispatch]);
const handleClickResetCanvasView = useSingleAndDoubleClick(
() => handleResetCanvasView(false),
@@ -174,36 +176,39 @@ const IAICanvasToolbar = () => {
);
};
const handleResetCanvas = () => {
const handleResetCanvas = useCallback(() => {
dispatch(resetCanvas());
};
}, [dispatch]);
const handleMergeVisible = () => {
const handleMergeVisible = useCallback(() => {
dispatch(canvasMerged());
};
}, [dispatch]);
const handleSaveToGallery = () => {
const handleSaveToGallery = useCallback(() => {
dispatch(canvasSavedToGallery());
};
}, [dispatch]);
const handleCopyImageToClipboard = () => {
const handleCopyImageToClipboard = useCallback(() => {
if (!isClipboardAPIAvailable) {
return;
}
dispatch(canvasCopiedToClipboard());
};
}, [dispatch, isClipboardAPIAvailable]);
const handleDownloadAsImage = () => {
const handleDownloadAsImage = useCallback(() => {
dispatch(canvasDownloadedAsImage());
};
}, [dispatch]);
const handleChangeLayer = (v: string) => {
const newLayer = v as CanvasLayer;
dispatch(setLayer(newLayer));
if (newLayer === 'mask' && !isMaskEnabled) {
dispatch(setIsMaskEnabled(true));
}
};
const handleChangeLayer = useCallback(
(v: string) => {
const newLayer = v as CanvasLayer;
dispatch(setLayer(newLayer));
if (newLayer === 'mask' && !isMaskEnabled) {
dispatch(setIsMaskEnabled(true));
}
},
[dispatch, isMaskEnabled]
);
return (
<Flex

View File

@@ -10,6 +10,7 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
import { stateSelector } from 'app/store/store';
import { useCallback } from 'react';
const canvasUndoSelector = createSelector(
[stateSelector, activeTabNameSelector],
@@ -35,9 +36,9 @@ export default function IAICanvasUndoButton() {
const { canUndo, activeTabName } = useAppSelector(canvasUndoSelector);
const handleUndo = () => {
const handleUndo = useCallback(() => {
dispatch(undo());
};
}, [dispatch]);
useHotkeys(
['meta+z', 'ctrl+z'],

View File

@@ -1,16 +0,0 @@
import Konva from 'konva';
import { IRect } from 'konva/lib/types';
/**
* Converts a Konva node to a dataURL
* @param node - The Konva node to convert to a dataURL
* @param boundingBox - The bounding box to crop to
* @returns A dataURL of the node cropped to the bounding box
*/
export const konvaNodeToDataURL = (
node: Konva.Node,
boundingBox: IRect
): string => {
// get a dataURL of the bbox'd region
return node.toDataURL(boundingBox);
};

View File

@@ -87,6 +87,11 @@ const ChangeBoardModal = () => {
selectedBoard,
]);
const handleSetSelectedBoard = useCallback(
(v: string | null) => setSelectedBoard(v),
[]
);
const cancelRef = useRef<HTMLButtonElement>(null);
return (
@@ -113,7 +118,7 @@ const ChangeBoardModal = () => {
isFetching ? t('boards.loading') : t('boards.selectBoard')
}
disabled={isFetching}
onChange={(v) => setSelectedBoard(v)}
onChange={handleSetSelectedBoard}
value={selectedBoard}
data={data}
/>

View File

@@ -1,36 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { useIsReadyToEnqueue } from 'common/hooks/useIsReadyToEnqueue';
import { memo, useCallback } from 'react';
import { useControlAdapterControlImage } from '../hooks/useControlAdapterControlImage';
import { controlAdapterImageProcessed } from '../store/actions';
type Props = {
id: string;
};
const ControlAdapterPreprocessButton = ({ id }: Props) => {
const controlImage = useControlAdapterControlImage(id);
const dispatch = useAppDispatch();
const isReady = useIsReadyToEnqueue();
const handleProcess = useCallback(() => {
dispatch(
controlAdapterImageProcessed({
id,
})
);
}, [id, dispatch]);
return (
<IAIButton
size="sm"
onClick={handleProcess}
isDisabled={Boolean(!controlImage) || !isReady}
>
Preprocess
</IAIButton>
);
};
export default memo(ControlAdapterPreprocessButton);

View File

@@ -14,9 +14,9 @@ import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectI
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es';
import { PropsWithChildren, memo, useCallback, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
import { useTranslation } from 'react-i18next';
type Props = PropsWithChildren & {
onSelect: (v: string) => void;
@@ -78,6 +78,13 @@ const ParamEmbeddingPopover = (props: Props) => {
[onSelect]
);
const filterFunc = useCallback(
(value: string, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()),
[]
);
return (
<Popover
initialFocusRef={inputRef}
@@ -127,12 +134,7 @@ const ParamEmbeddingPopover = (props: Props) => {
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0}
onDropdownClose={onClose}
filter={(value, item: SelectItem) =>
item.label
?.toLowerCase()
.includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
filter={filterFunc}
onChange={handleChange}
/>
)}

View File

@@ -60,6 +60,13 @@ const BoardAutoAddSelect = () => {
[dispatch]
);
const filterFunc = useCallback(
(value: string, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()),
[]
);
return (
<IAIMantineSearchableSelect
label={t('boards.autoAddBoard')}
@@ -71,10 +78,7 @@ const BoardAutoAddSelect = () => {
nothingFound={t('boards.noMatching')}
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={!hasBoards || autoAssignBoardOnClick}
filter={(value, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
filter={filterFunc}
onChange={handleChange}
/>
);

View File

@@ -90,6 +90,50 @@ const BoardContextMenu = ({
e.preventDefault();
}, []);
const renderMenuFunc = useCallback(
() => (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<MenuGroup title={boardName}>
<MenuItem
icon={<FaPlus />}
isDisabled={isAutoAdd || autoAssignBoardOnClick}
onClick={handleSetAutoAdd}
>
{t('boards.menuItemAutoAdd')}
</MenuItem>
{isBulkDownloadEnabled && (
<MenuItem icon={<FaDownload />} onClickCapture={handleBulkDownload}>
{t('boards.downloadBoard')}
</MenuItem>
)}
{!board && <NoBoardContextMenuItems />}
{board && (
<GalleryBoardContextMenuItems
board={board}
setBoardToDelete={setBoardToDelete}
/>
)}
</MenuGroup>
</MenuList>
),
[
autoAssignBoardOnClick,
board,
boardName,
handleBulkDownload,
handleSetAutoAdd,
isAutoAdd,
isBulkDownloadEnabled,
setBoardToDelete,
skipEvent,
t,
]
);
return (
<IAIContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
@@ -97,38 +141,7 @@ const BoardContextMenu = ({
bg: 'transparent',
_hover: { bg: 'transparent' },
}}
renderMenu={() => (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<MenuGroup title={boardName}>
<MenuItem
icon={<FaPlus />}
isDisabled={isAutoAdd || autoAssignBoardOnClick}
onClick={handleSetAutoAdd}
>
{t('boards.menuItemAutoAdd')}
</MenuItem>
{isBulkDownloadEnabled && (
<MenuItem
icon={<FaDownload />}
onClickCapture={handleBulkDownload}
>
{t('boards.downloadBoard')}
</MenuItem>
)}
{!board && <NoBoardContextMenuItems />}
{board && (
<GalleryBoardContextMenuItems
board={board}
setBoardToDelete={setBoardToDelete}
/>
)}
</MenuGroup>
</MenuList>
)}
renderMenu={renderMenuFunc}
>
{children}
</IAIContextMenu>

View File

@@ -1,108 +0,0 @@
import { As, Badge, Flex } from '@chakra-ui/react';
import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { TypesafeDroppableData } from 'features/dnd/types';
import { BoardId } from 'features/gallery/store/types';
import { ReactNode, memo } from 'react';
import BoardContextMenu from '../BoardContextMenu';
type GenericBoardProps = {
board_id: BoardId;
droppableData?: TypesafeDroppableData;
onClick: () => void;
isSelected: boolean;
icon: As;
label: string;
dropLabel?: ReactNode;
badgeCount?: number;
};
export const formatBadgeCount = (count: number) =>
Intl.NumberFormat('en-US', {
notation: 'compact',
maximumFractionDigits: 1,
}).format(count);
const GenericBoard = (props: GenericBoardProps) => {
const {
board_id,
droppableData,
onClick,
isSelected,
icon,
label,
badgeCount,
dropLabel,
} = props;
return (
<BoardContextMenu board_id={board_id}>
{(ref) => (
<Flex
ref={ref}
sx={{
flexDir: 'column',
justifyContent: 'space-between',
alignItems: 'center',
cursor: 'pointer',
w: 'full',
h: 'full',
borderRadius: 'base',
}}
>
<Flex
onClick={onClick}
sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
w: 'full',
aspectRatio: '1/1',
overflow: 'hidden',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
flexShrink: 0,
}}
>
<IAINoContentFallback
boxSize={8}
icon={icon}
sx={{
border: '2px solid var(--invokeai-colors-base-200)',
_dark: { border: '2px solid var(--invokeai-colors-base-800)' },
}}
/>
<Flex
sx={{
position: 'absolute',
insetInlineEnd: 0,
top: 0,
p: 1,
}}
>
{badgeCount !== undefined && (
<Badge variant="solid">{formatBadgeCount(badgeCount)}</Badge>
)}
</Flex>
<IAIDroppable data={droppableData} dropLabel={dropLabel} />
</Flex>
<Flex
sx={{
h: 'full',
alignItems: 'center',
fontWeight: isSelected ? 600 : undefined,
fontSize: 'sm',
color: isSelected ? 'base.900' : 'base.700',
_dark: { color: isSelected ? 'base.50' : 'base.200' },
}}
>
{label}
</Flex>
</Flex>
)}
</BoardContextMenu>
);
};
export default memo(GenericBoard);

View File

@@ -1,53 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName';
type Props = {
board_id: 'images' | 'assets' | 'no_board';
};
const SystemBoardButton = ({ board_id }: Props) => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
[stateSelector],
({ gallery }) => {
const { selectedBoardId } = gallery;
return { isSelected: selectedBoardId === board_id };
},
defaultSelectorOptions
),
[board_id]
);
const { isSelected } = useAppSelector(selector);
const boardName = useBoardName(board_id);
const handleClick = useCallback(() => {
dispatch(boardIdSelected({ boardId: board_id }));
}, [board_id, dispatch]);
return (
<IAIButton
onClick={handleClick}
size="sm"
isChecked={isSelected}
sx={{
flexGrow: 1,
borderRadius: 'base',
}}
>
{boardName}
</IAIButton>
);
};
export default memo(SystemBoardButton);

View File

@@ -1,22 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { memo } from 'react';
import { FaEyeSlash } from 'react-icons/fa';
const CurrentImageHidden = () => {
return (
<Flex
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'absolute',
color: 'base.400',
}}
>
<FaEyeSlash fontSize="25vh" />
</Flex>
);
};
export default memo(CurrentImageHidden);

View File

@@ -61,6 +61,12 @@ const GallerySettingsPopover = () => {
[dispatch]
);
const handleChangeAutoAssignBoardOnClick = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(autoAssignBoardOnClickChanged(e.target.checked)),
[dispatch]
);
return (
<IAIPopover
triggerComponent={
@@ -91,9 +97,7 @@ const GallerySettingsPopover = () => {
<IAISimpleCheckbox
label={t('gallery.autoAssignBoardOnClick')}
isChecked={autoAssignBoardOnClick}
onChange={(e: ChangeEvent<HTMLInputElement>) =>
dispatch(autoAssignBoardOnClickChanged(e.target.checked))
}
onChange={handleChangeAutoAssignBoardOnClick}
/>
<BoardAutoAddSelect />
</Flex>

View File

@@ -35,6 +35,34 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
e.preventDefault();
}, []);
const renderMenuFunc = useCallback(() => {
if (!imageDTO) {
return null;
}
if (selectionCount > 1) {
return (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<MultipleSelectionMenuItems />
</MenuList>
);
}
return (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<SingleSelectionMenuItems imageDTO={imageDTO} />
</MenuList>
);
}, [imageDTO, selectionCount, skipEvent]);
return (
<IAIContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
@@ -42,33 +70,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
bg: 'transparent',
_hover: { bg: 'transparent' },
}}
renderMenu={() => {
if (!imageDTO) {
return null;
}
if (selectionCount > 1) {
return (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<MultipleSelectionMenuItems />
</MenuList>
);
}
return (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<SingleSelectionMenuItems imageDTO={imageDTO} />
</MenuList>
);
}}
renderMenu={renderMenuFunc}
>
{children}
</IAIContextMenu>

View File

@@ -13,7 +13,7 @@ import { workflowLoadRequested } from 'features/nodes/store/actions';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { useCopyImageToClipboard } from 'common/hooks/useCopyImageToClipboard';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
import { flushSync } from 'react-dom';

View File

@@ -1,27 +0,0 @@
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
import { memo } from 'react';
type ImageFallbackSpinnerProps = SpinnerProps;
const ImageFallbackSpinner = (props: ImageFallbackSpinnerProps) => {
const { size = 'xl', ...rest } = props;
return (
<Flex
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'absolute',
color: 'base.400',
minH: 36,
minW: 36,
}}
>
<Spinner size={size} {...rest} />
</Flex>
);
};
export default memo(ImageFallbackSpinner);

View File

@@ -20,6 +20,7 @@ import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
import GalleryImage from './GalleryImage';
import ImageGridItemContainer from './ImageGridItemContainer';
import ImageGridListContainer from './ImageGridListContainer';
import { EntityId } from '@reduxjs/toolkit';
const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
defer: true,
@@ -71,6 +72,13 @@ const GalleryImageGrid = () => {
});
}, [areMoreAvailable, listImages, queryArgs, currentData?.ids.length]);
const itemContentFunc = useCallback(
(index: number, imageName: EntityId) => (
<GalleryImage key={imageName} imageName={imageName as string} />
),
[]
);
useEffect(() => {
// Initialize the gallery's custom scrollbar
const { current: root } = rootRef;
@@ -131,9 +139,7 @@ const GalleryImageGrid = () => {
List: ImageGridListContainer,
}}
scrollerRef={setScroller}
itemContent={(index, imageName) => (
<GalleryImage key={imageName} imageName={imageName as string} />
)}
itemContent={itemContentFunc}
/>
</Box>
<IAIButton

View File

@@ -279,7 +279,7 @@ const ImageMetadataActions = (props: Props) => {
key={index}
label="LoRA"
value={`${lora.lora.model_name} - ${lora.weight}`}
onClick={() => handleRecallLoRA(lora)}
onClick={handleRecallLoRA.bind(null, lora)}
/>
);
}
@@ -289,7 +289,7 @@ const ImageMetadataActions = (props: Props) => {
key={index}
label="ControlNet"
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
onClick={() => handleRecallControlNet(controlnet)}
onClick={handleRecallControlNet.bind(null, controlnet)}
/>
))}
{validIPAdapters.map((ipAdapter, index) => (
@@ -297,7 +297,7 @@ const ImageMetadataActions = (props: Props) => {
key={index}
label="IP Adapter"
value={`${ipAdapter.ip_adapter_model?.model_name} - ${ipAdapter.weight}`}
onClick={() => handleRecallIPAdapter(ipAdapter)}
onClick={handleRecallIPAdapter.bind(null, ipAdapter)}
/>
))}
{validT2IAdapters.map((t2iAdapter, index) => (
@@ -305,7 +305,7 @@ const ImageMetadataActions = (props: Props) => {
key={index}
label="T2I Adapter"
value={`${t2iAdapter.t2i_adapter_model?.model_name} - ${t2iAdapter.weight}`}
onClick={() => handleRecallT2IAdapter(t2iAdapter)}
onClick={handleRecallT2IAdapter.bind(null, t2iAdapter)}
/>
))}
</>

View File

@@ -1,6 +1,6 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { Flex, IconButton, Link, Text, Tooltip } from '@chakra-ui/react';
import { memo } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
@@ -27,6 +27,11 @@ const ImageMetadataItem = ({
}: MetadataItemProps) => {
const { t } = useTranslation();
const handleCopy = useCallback(
() => navigator.clipboard.writeText(value.toString()),
[value]
);
if (!value) {
return null;
}
@@ -53,7 +58,7 @@ const ImageMetadataItem = ({
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(value.toString())}
onClick={handleCopy}
/>
</Tooltip>
)}

View File

@@ -76,6 +76,13 @@ const ParamLoRASelect = () => {
[dispatch, loraModels?.entities]
);
const filterFunc = useCallback(
(value: string, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()),
[]
);
if (loraModels?.ids.length === 0) {
return (
<Flex sx={{ justifyContent: 'center', p: 2 }}>
@@ -94,10 +101,7 @@ const ParamLoRASelect = () => {
nothingFound="No matching LoRAs"
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0}
filter={(value, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
filter={filterFunc}
onChange={handleChange}
data-testid="add-lora"
/>

View File

@@ -1,6 +1,6 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { useState } from 'react';
import { useCallback, useState } from 'react';
import AdvancedAddModels from './AdvancedAddModels';
import SimpleAddModels from './SimpleAddModels';
@@ -8,6 +8,11 @@ export default function AddModels() {
const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>(
'simple'
);
const handleAddModelSimple = useCallback(() => setAddModelMode('simple'), []);
const handleAddModelAdvanced = useCallback(
() => setAddModelMode('advanced'),
[]
);
return (
<Flex
flexDirection="column"
@@ -20,14 +25,14 @@ export default function AddModels() {
<IAIButton
size="sm"
isChecked={addModelMode == 'simple'}
onClick={() => setAddModelMode('simple')}
onClick={handleAddModelSimple}
>
Simple
</IAIButton>
<IAIButton
size="sm"
isChecked={addModelMode == 'advanced'}
onClick={() => setAddModelMode('advanced')}
onClick={handleAddModelAdvanced}
>
Advanced
</IAIButton>

View File

@@ -6,7 +6,7 @@ import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useState } from 'react';
import { FocusEventHandler, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
@@ -83,6 +83,27 @@ export default function AdvancedAddCheckpoint(
});
};
const handleBlurModelLocation: FocusEventHandler<HTMLInputElement> =
useCallback(
(e) => {
if (advancedAddCheckpointForm.values['model_name'] === '') {
const modelName = getModelName(e.currentTarget.value);
if (modelName) {
advancedAddCheckpointForm.setFieldValue(
'model_name',
modelName as string
);
}
}
},
[advancedAddCheckpointForm]
);
const handleChangeUseCustomConfig = useCallback(
() => setUseCustomConfig((prev) => !prev),
[]
);
return (
<form
onSubmit={advancedAddCheckpointForm.onSubmit((v) =>
@@ -104,17 +125,7 @@ export default function AdvancedAddCheckpoint(
label={t('modelManager.modelLocation')}
required
{...advancedAddCheckpointForm.getInputProps('path')}
onBlur={(e) => {
if (advancedAddCheckpointForm.values['model_name'] === '') {
const modelName = getModelName(e.currentTarget.value);
if (modelName) {
advancedAddCheckpointForm.setFieldValue(
'model_name',
modelName as string
);
}
}
}}
onBlur={handleBlurModelLocation}
/>
<IAIMantineTextInput
label={t('modelManager.description')}
@@ -144,7 +155,7 @@ export default function AdvancedAddCheckpoint(
)}
<IAISimpleCheckbox
isChecked={useCustomConfig}
onChange={() => setUseCustomConfig(!useCustomConfig)}
onChange={handleChangeUseCustomConfig}
label={t('modelManager.useCustomConfig')}
/>
<IAIButton mt={2} type="submit">

View File

@@ -12,6 +12,7 @@ import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import BaseModelSelect from '../shared/BaseModelSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
import { getModelName } from './util';
import { FocusEventHandler, useCallback } from 'react';
type AdvancedAddDiffusersProps = {
model_path?: string;
@@ -74,6 +75,22 @@ export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
});
};
const handleBlurModelLocation: FocusEventHandler<HTMLInputElement> =
useCallback(
(e) => {
if (advancedAddDiffusersForm.values['model_name'] === '') {
const modelName = getModelName(e.currentTarget.value, false);
if (modelName) {
advancedAddDiffusersForm.setFieldValue(
'model_name',
modelName as string
);
}
}
},
[advancedAddDiffusersForm]
);
return (
<form
onSubmit={advancedAddDiffusersForm.onSubmit((v) =>
@@ -96,17 +113,7 @@ export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
label={t('modelManager.modelLocation')}
placeholder={t('modelManager.modelLocationValidationMsg')}
{...advancedAddDiffusersForm.getInputProps('path')}
onBlur={(e) => {
if (advancedAddDiffusersForm.values['model_name'] === '') {
const modelName = getModelName(e.currentTarget.value, false);
if (modelName) {
advancedAddDiffusersForm.setFieldValue(
'model_name',
modelName as string
);
}
}
}}
onBlur={handleBlurModelLocation}
/>
<IAIMantineTextInput
label={t('modelManager.description')}

View File

@@ -1,7 +1,7 @@
import { Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useState } from 'react';
import { useCallback, useState } from 'react';
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
import { useTranslation } from 'react-i18next';
@@ -18,6 +18,12 @@ export default function AdvancedAddModels() {
useState<ManualAddMode>('diffusers');
const { t } = useTranslation();
const handleChange = useCallback((v: string | null) => {
if (!v) {
return;
}
setAdvancedAddMode(v as ManualAddMode);
}, []);
return (
<Flex flexDirection="column" gap={4} width="100%">
@@ -25,12 +31,7 @@ export default function AdvancedAddModels() {
label={t('modelManager.modelType')}
value={advancedAddMode}
data={advancedAddModeData}
onChange={(v) => {
if (!v) {
return;
}
setAdvancedAddMode(v as ManualAddMode);
}}
onChange={handleChange}
/>
<Flex

View File

@@ -92,6 +92,11 @@ export default function FoundModelsList() {
setNameFilter(e.target.value);
}, []);
const handleClickSetAdvanced = useCallback(
(model: string) => dispatch(setAdvancedAddScanModel(model)),
[dispatch]
);
const renderModels = ({
models,
showActions = true,
@@ -140,7 +145,7 @@ export default function FoundModelsList() {
{t('modelManager.quickAdd')}
</IAIButton>
<IAIButton
onClick={() => dispatch(setAdvancedAddScanModel(model))}
onClick={handleClickSetAdvanced.bind(null, model)}
isLoading={isLoading}
>
{t('modelManager.advanced')}

View File

@@ -4,7 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { motion } from 'framer-motion';
import { useEffect, useState } from 'react';
import { useCallback, useEffect, useState } from 'react';
import { FaTimes } from 'react-icons/fa';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
@@ -35,6 +35,23 @@ export default function ScanAdvancedAddModels() {
const dispatch = useAppDispatch();
const handleClickSetAdvanced = useCallback(
() => dispatch(setAdvancedAddScanModel(null)),
[dispatch]
);
const handleChangeAddMode = useCallback((v: string | null) => {
if (!v) {
return;
}
setAdvancedAddMode(v as ManualAddMode);
if (v === 'checkpoint') {
setIsCheckpoint(true);
} else {
setIsCheckpoint(false);
}
}, []);
if (!advancedAddScanModel) {
return null;
}
@@ -68,7 +85,7 @@ export default function ScanAdvancedAddModels() {
<IAIIconButton
icon={<FaTimes />}
aria-label={t('modelManager.closeAdvanced')}
onClick={() => dispatch(setAdvancedAddScanModel(null))}
onClick={handleClickSetAdvanced}
size="sm"
/>
</Flex>
@@ -76,17 +93,7 @@ export default function ScanAdvancedAddModels() {
label={t('modelManager.modelType')}
value={advancedAddMode}
data={advancedAddModeData}
onChange={(v) => {
if (!v) {
return;
}
setAdvancedAddMode(v as ManualAddMode);
if (v === 'checkpoint') {
setIsCheckpoint(true);
} else {
setIsCheckpoint(false);
}
}}
onChange={handleChangeAddMode}
/>
{isCheckpoint ? (
<AdvancedAddCheckpoint

View File

@@ -42,9 +42,14 @@ function SearchFolderForm() {
[dispatch]
);
const scanAgainHandler = () => {
const scanAgainHandler = useCallback(() => {
refetchFoundModels();
};
}, [refetchFoundModels]);
const handleClickClearCheckpointFolder = useCallback(() => {
dispatch(setSearchFolder(null));
dispatch(setAdvancedAddScanModel(null));
}, [dispatch]);
return (
<form
@@ -123,10 +128,7 @@ function SearchFolderForm() {
tooltip={t('modelManager.clearCheckpointFolder')}
icon={<FaTrash />}
size="sm"
onClick={() => {
dispatch(setSearchFolder(null));
dispatch(setAdvancedAddScanModel(null));
}}
onClick={handleClickClearCheckpointFolder}
isDisabled={!searchFolder}
colorScheme="red"
/>

View File

@@ -1,6 +1,6 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { useState } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import AddModels from './AddModelsPanel/AddModels';
import ScanModels from './AddModelsPanel/ScanModels';
@@ -11,11 +11,14 @@ export default function ImportModelsPanel() {
const [addModelTab, setAddModelTab] = useState<AddModelTabs>('add');
const { t } = useTranslation();
const handleClickAddTab = useCallback(() => setAddModelTab('add'), []);
const handleClickScanTab = useCallback(() => setAddModelTab('scan'), []);
return (
<Flex flexDirection="column" gap={4}>
<ButtonGroup isAttached>
<IAIButton
onClick={() => setAddModelTab('add')}
onClick={handleClickAddTab}
isChecked={addModelTab == 'add'}
size="sm"
width="100%"
@@ -23,7 +26,7 @@ export default function ImportModelsPanel() {
{t('modelManager.addModel')}
</IAIButton>
<IAIButton
onClick={() => setAddModelTab('scan')}
onClick={handleClickScanTab}
isChecked={addModelTab == 'scan'}
size="sm"
width="100%"

View File

@@ -9,7 +9,7 @@ import IAISlider from 'common/components/IAISlider';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { pickBy } from 'lodash-es';
import { useMemo, useState } from 'react';
import { ChangeEvent, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import {
@@ -94,13 +94,58 @@ export default function MergeModelsPanel() {
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelOne && model !== modelTwo);
const handleBaseModelChange = (v: string) => {
const handleBaseModelChange = useCallback((v: string) => {
setBaseModel(v as BaseModelType);
setModelOne(null);
setModelTwo(null);
};
}, []);
const mergeModelsHandler = () => {
const handleChangeModelOne = useCallback((v: string) => {
setModelOne(v);
}, []);
const handleChangeModelTwo = useCallback((v: string) => {
setModelTwo(v);
}, []);
const handleChangeModelThree = useCallback((v: string) => {
if (!v) {
setModelThree(null);
setModelMergeInterp('add_difference');
} else {
setModelThree(v);
setModelMergeInterp('weighted_sum');
}
}, []);
const handleChangeMergedModelName = useCallback(
(e: ChangeEvent<HTMLInputElement>) => setMergedModelName(e.target.value),
[]
);
const handleChangeModelMergeAlpha = useCallback(
(v: number) => setModelMergeAlpha(v),
[]
);
const handleResetModelMergeAlpha = useCallback(
() => setModelMergeAlpha(0.5),
[]
);
const handleChangeMergeInterp = useCallback(
(v: MergeInterpolationMethods) => setModelMergeInterp(v),
[]
);
const handleChangeMergeSaveLocType = useCallback(
(v: 'root' | 'custom') => setModelMergeSaveLocType(v),
[]
);
const handleChangeMergeCustomSaveLoc = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
setModelMergeCustomSaveLoc(e.target.value),
[]
);
const handleChangeModelMergeForce = useCallback(
(e: ChangeEvent<HTMLInputElement>) => setModelMergeForce(e.target.checked),
[]
);
const mergeModelsHandler = useCallback(() => {
const models_names: string[] = [];
let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
@@ -150,7 +195,21 @@ export default function MergeModelsPanel() {
);
}
});
};
}, [
baseModel,
dispatch,
mergeModels,
mergedModelName,
modelMergeAlpha,
modelMergeCustomSaveLoc,
modelMergeForce,
modelMergeInterp,
modelMergeSaveLocType,
modelOne,
modelThree,
modelTwo,
t,
]);
return (
<Flex flexDirection="column" rowGap={4}>
@@ -180,7 +239,7 @@ export default function MergeModelsPanel() {
value={modelOne}
placeholder={t('modelManager.selectModel')}
data={modelOneList}
onChange={(v) => setModelOne(v)}
onChange={handleChangeModelOne}
/>
<IAIMantineSearchableSelect
label={t('modelManager.modelTwo')}
@@ -188,7 +247,7 @@ export default function MergeModelsPanel() {
placeholder={t('modelManager.selectModel')}
value={modelTwo}
data={modelTwoList}
onChange={(v) => setModelTwo(v)}
onChange={handleChangeModelTwo}
/>
<IAIMantineSearchableSelect
label={t('modelManager.modelThree')}
@@ -196,22 +255,14 @@ export default function MergeModelsPanel() {
w="100%"
placeholder={t('modelManager.selectModel')}
clearable
onChange={(v) => {
if (!v) {
setModelThree(null);
setModelMergeInterp('add_difference');
} else {
setModelThree(v);
setModelMergeInterp('weighted_sum');
}
}}
onChange={handleChangeModelThree}
/>
</Flex>
<IAIInput
label={t('modelManager.mergedModelName')}
value={mergedModelName}
onChange={(e) => setMergedModelName(e.target.value)}
onChange={handleChangeMergedModelName}
/>
<Flex
@@ -232,10 +283,10 @@ export default function MergeModelsPanel() {
max={0.99}
step={0.01}
value={modelMergeAlpha}
onChange={(v) => setModelMergeAlpha(v)}
onChange={handleChangeModelMergeAlpha}
withInput
withReset
handleReset={() => setModelMergeAlpha(0.5)}
handleReset={handleResetModelMergeAlpha}
withSliderMarks
/>
<Text variant="subtext" fontSize="sm">
@@ -257,10 +308,7 @@ export default function MergeModelsPanel() {
<Text fontWeight={500} fontSize="sm" variant="subtext">
{t('modelManager.interpolationType')}
</Text>
<RadioGroup
value={modelMergeInterp}
onChange={(v: MergeInterpolationMethods) => setModelMergeInterp(v)}
>
<RadioGroup value={modelMergeInterp} onChange={handleChangeMergeInterp}>
<Flex columnGap={4}>
{modelThree === null ? (
<>
@@ -305,7 +353,7 @@ export default function MergeModelsPanel() {
</Text>
<RadioGroup
value={modelMergeSaveLocType}
onChange={(v: 'root' | 'custom') => setModelMergeSaveLocType(v)}
onChange={handleChangeMergeSaveLocType}
>
<Flex columnGap={4}>
<Radio value="root">
@@ -323,7 +371,7 @@ export default function MergeModelsPanel() {
<IAIInput
label={t('modelManager.mergedModelCustomSaveLocation')}
value={modelMergeCustomSaveLoc}
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
onChange={handleChangeMergeCustomSaveLoc}
/>
)}
</Flex>
@@ -331,7 +379,7 @@ export default function MergeModelsPanel() {
<IAISimpleCheckbox
label={t('modelManager.ignoreMismatch')}
isChecked={modelMergeForce}
onChange={(e) => setModelMergeForce(e.target.checked)}
onChange={handleChangeModelMergeForce}
fontWeight="500"
/>

View File

@@ -59,6 +59,11 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
},
});
const handleChangeUseCustomConfig = useCallback(
() => setUseCustomConfig((prev) => !prev),
[]
);
const editModelFormSubmitHandler = useCallback(
(values: CheckpointModelConfig) => {
const responseBody = {
@@ -181,7 +186,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
)}
<IAISimpleCheckbox
isChecked={useCustomConfig}
onChange={() => setUseCustomConfig(!useCustomConfig)}
onChange={handleChangeUseCustomConfig}
label="Use Custom Config"
/>
</Flex>

View File

@@ -14,7 +14,7 @@ import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { addToast } from 'features/system/store/systemSlice';
import { useEffect, useState } from 'react';
import { ChangeEvent, useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
@@ -42,11 +42,21 @@ export default function ModelConvert(props: ModelConvertProps) {
setSaveLocation('InvokeAIRoot');
}, [model]);
const modelConvertCancelHandler = () => {
const modelConvertCancelHandler = useCallback(() => {
setSaveLocation('InvokeAIRoot');
};
}, []);
const modelConvertHandler = () => {
const handleChangeSaveLocation = useCallback((v: string) => {
setSaveLocation(v as SaveLocation);
}, []);
const handleChangeCustomSaveLocation = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
setCustomSaveLocation(e.target.value);
},
[]
);
const modelConvertHandler = useCallback(() => {
const queryArg = {
base_model: model.base_model,
model_name: model.model_name,
@@ -101,7 +111,15 @@ export default function ModelConvert(props: ModelConvertProps) {
)
);
});
};
}, [
convertModel,
customSaveLocation,
dispatch,
model.base_model,
model.model_name,
saveLocation,
t,
]);
return (
<IAIAlertDialog
@@ -137,10 +155,7 @@ export default function ModelConvert(props: ModelConvertProps) {
<Text fontWeight="600">
{t('modelManager.convertToDiffusersSaveLocation')}
</Text>
<RadioGroup
value={saveLocation}
onChange={(v) => setSaveLocation(v as SaveLocation)}
>
<RadioGroup value={saveLocation} onChange={handleChangeSaveLocation}>
<Flex gap={4}>
<Radio value="InvokeAIRoot">
<Tooltip label="Save converted model in the InvokeAI root folder">
@@ -162,9 +177,7 @@ export default function ModelConvert(props: ModelConvertProps) {
</Text>
<IAIInput
value={customSaveLocation}
onChange={(e) => {
setCustomSaveLocation(e.target.value);
}}
onChange={handleChangeCustomSaveLocation}
width="full"
/>
</Flex>

View File

@@ -100,7 +100,7 @@ const ModelList = (props: ModelListProps) => {
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
<ButtonGroup isAttached>
<IAIButton
onClick={() => setModelFormatFilter('all')}
onClick={setModelFormatFilter.bind(null, 'all')}
isChecked={modelFormatFilter === 'all'}
size="sm"
>
@@ -108,35 +108,35 @@ const ModelList = (props: ModelListProps) => {
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('diffusers')}
onClick={setModelFormatFilter.bind(null, 'diffusers')}
isChecked={modelFormatFilter === 'diffusers'}
>
{t('modelManager.diffusersModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('checkpoint')}
onClick={setModelFormatFilter.bind(null, 'checkpoint')}
isChecked={modelFormatFilter === 'checkpoint'}
>
{t('modelManager.checkpointModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('onnx')}
onClick={setModelFormatFilter.bind(null, 'onnx')}
isChecked={modelFormatFilter === 'onnx'}
>
{t('modelManager.onnxModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('olive')}
onClick={setModelFormatFilter.bind(null, 'olive')}
isChecked={modelFormatFilter === 'olive'}
>
{t('modelManager.oliveModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('lora')}
onClick={setModelFormatFilter.bind(null, 'lora')}
isChecked={modelFormatFilter === 'lora'}
>
{t('modelManager.loraModels')}

View File

@@ -4,6 +4,7 @@ import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSync } from 'react-icons/fa';
import { useSyncModelsMutation } from 'services/api/endpoints/models';
@@ -19,7 +20,7 @@ export default function SyncModelsButton(props: SyncModelsButtonProps) {
const [syncModels, { isLoading }] = useSyncModelsMutation();
const syncModelsHandler = () => {
const syncModelsHandler = useCallback(() => {
syncModels()
.unwrap()
.then((_) => {
@@ -44,7 +45,7 @@ export default function SyncModelsButton(props: SyncModelsButtonProps) {
);
}
});
};
}, [dispatch, syncModels, t]);
return !iconMode ? (
<IAIButton

View File

@@ -1,4 +1,4 @@
import { useState, PropsWithChildren, memo } from 'react';
import { useState, PropsWithChildren, memo, useCallback } from 'react';
import { useSelector } from 'react-redux';
import { createSelector } from '@reduxjs/toolkit';
import { Flex, Image, Text } from '@chakra-ui/react';
@@ -59,13 +59,13 @@ export default memo(CurrentImageNode);
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => {
const [isHovering, setIsHovering] = useState(false);
const handleMouseEnter = () => {
const handleMouseEnter = useCallback(() => {
setIsHovering(true);
};
}, []);
const handleMouseLeave = () => {
const handleMouseLeave = useCallback(() => {
setIsHovering(false);
};
}, []);
return (
<NodeWrapper

Some files were not shown because too many files have changed in this diff Show More