Compare commits


139 Commits

Author SHA1 Message Date
psychedelicious
143487a492 chore: bump version to v5.11.0 2025-05-13 14:04:45 +10:00
psychedelicious
203fa04295 feat(nodes): support bottleneck flag for nodes 2025-05-13 11:56:40 +10:00
Mary Hipp Rogers
954fce3c67 feat(ui): custom error toast support (#8001)
* support for custom error toast components, starting with usage limit

* add support for all usage limits

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2025-05-08 15:53:10 -04:00
Mary Hipp
821889148a easier way to override Whats New 2025-05-07 15:40:21 -04:00
Mary Hipp
4c248d8c2c refetch queue list on mount 2025-05-07 15:37:55 -04:00
Mary Hipp
deb75805d4 use the max for iterations passed in 2025-05-06 18:26:40 -04:00
Mary Hipp Rogers
93110654da Change feature to disable apiModels to chatGPT4oModels only (#7996)
* display credit column in queue list if shouldShowCredits is true

* change apiModels feature to chatGPT4oModels feature

* empty

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2025-05-06 14:37:03 -04:00
psychedelicious
ff0c48d532 chore(ui): prettier 2025-05-06 09:07:52 -04:00
psychedelicious
de18073814 feat(ui): support imagen3/chatgpt-4o models in canvas 2025-05-06 09:07:52 -04:00
psychedelicious
0708af9545 feat(ui): support imagen3/chatgpt-4o models in workflow editor 2025-05-06 09:07:52 -04:00
psychedelicious
1e85184c62 feat(nodes): add imagen3/chatgpt-4o field types 2025-05-06 09:07:52 -04:00
psychedelicious
11d3b8d944 feat(ui): add usage info to model picker 2025-05-06 09:07:52 -04:00
psychedelicious
bffd4afb96 chore(ui): typegen 2025-05-06 09:07:52 -04:00
psychedelicious
518a896521 feat(mm): add usage_info to model config 2025-05-06 09:07:52 -04:00
psychedelicious
2647ff141a feat(ui): add basic metadata to imagen3/chatgpt-4o graphs 2025-05-06 09:07:52 -04:00
Mary Hipp Rogers
ba0bac2aa5 add credits to queue item status changed (#7993)
* display credit column in queue list if shouldShowCredits is true

* add credits when queue item status changes

* chore(ui): typegen

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-05-06 08:54:44 -04:00
psychedelicious
862e2a3e49 chore(ui): typegen 2025-05-05 16:09:13 -04:00
Mary Hipp
d22fd32b05 typegen 2025-05-05 16:09:13 -04:00
Mary Hipp
391e5b7f8c update schema 2025-05-05 16:09:13 -04:00
Mary Hipp
c9d2a5f59a display credit column in queue list if shouldShowCredits is true 2025-05-05 16:09:13 -04:00
Kent Keirsey
1f63b60021 Implementing support for Non-Standard LoRA Format (#7985)
* integrate loRA

* idk anymore tbh

* enable fused matrix for quantized models

* integrate loRA

* idk anymore tbh

* enable fused matrix for quantized models

* ruff fix

---------

Co-authored-by: Sam <bhaskarmdutt@gmail.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-05-05 09:40:38 -04:00
psychedelicious
a499b9f54e chore: bump version to v5.11.0rc2 2025-05-05 23:32:27 +10:00
psychedelicious
104505ea02 chore(ui): lint 2025-05-05 23:25:29 +10:00
psychedelicious
ee4002607c feat(ui): add UI to reset hf token 2025-05-05 23:25:29 +10:00
psychedelicious
fd20582cdd chore(ui): typegen 2025-05-05 23:25:29 +10:00
psychedelicious
43b0d07517 feat(api): add route to reset hf token 2025-05-05 23:25:29 +10:00
blessedcoolant
f83592a052 fix: deprecation warning in get_iso_timestemp 2025-05-05 11:45:30 +10:00
Mary Hipp
b3ee906749 add prompt validation to imagen3 graph 2025-05-01 13:02:13 -04:00
psychedelicious
5d69e9068a feat(ui): add ability to globally disable hotkeys
This will both hide the hotkey from the hotkey modal and override any other enabled status it has.
2025-05-01 10:50:34 -04:00
psychedelicious
a79136b058 fix(ui): always add selectModelsTab hotkey data to prevent unhandled exception while registering the hotkey handler 2025-05-01 10:50:34 -04:00
psychedelicious
944af4d4a9 feat(ui): show unsupported gen mode toasts as warnings instead of errors 2025-05-01 23:25:01 +10:00
psychedelicious
5e001be73a tidy(ui): remove excessive nav to mm buttons 2025-05-01 23:22:19 +10:00
psychedelicious
576a644b3a tidy(ui): modelpicker component 2025-05-01 23:22:19 +10:00
psychedelicious
703557c8a6 feat(ui): cleanup 2025-05-01 23:22:19 +10:00
psychedelicious
d59a53b3f9 feat(ui): simplify picker types 2025-05-01 23:22:19 +10:00
psychedelicious
7b8f78c2d9 fix(ui): focus bug w/ popover 2025-05-01 23:22:19 +10:00
psychedelicious
31ab9be79a feat(ui): iterate on picker 2025-05-01 23:22:19 +10:00
psychedelicious
5011fab85d fix(ui): restore FLUX Dev info popover to main model picker 2025-05-01 10:59:51 +10:00
psychedelicious
92bdb9fdcc chore(ui): remove unused exports 2025-05-01 10:59:51 +10:00
Mary Hipp
548e766c0b feat(ui): ability to disable generating with API models 2025-05-01 10:59:51 +10:00
Mary Hipp
ff897f74a1 send the list of reference images reversed to chatGPT so it matches displayed order 2025-04-30 15:56:38 -04:00
psychedelicious
3d29c996ed feat(ui): support img2img for chatgpt 4o w/ ref images 2025-04-30 13:39:05 +10:00
psychedelicious
42d57d1225 fix(ui): ref image layout 2025-04-30 13:39:05 +10:00
psychedelicious
193fa9395a fix(ui): match ref image model to main model when creating global ref image 2025-04-30 13:39:05 +10:00
psychedelicious
56cd839d5b feat(ui): support for ref images for chatgpt on canvas 2025-04-30 13:39:05 +10:00
ubansi
7b446ee40d docs: fix Contribute node import error
When I followed the Contribute Node documentation, I encountered an import error.
This commit fixes the error, which will help reduce debugging time for all future contributors.
2025-04-29 21:03:00 -04:00
Mary Hipp Rogers
17027c4070 Maryhipp/chatgpt UI (#7969)
* add GPTimage1 as allowed base model

* fix for non-disabled inpaint layers

* lots of boilerplate for adding gpt-image base model and disabling things along with imagen

* handle gpt-image dimensions

* build graph for gpt-image

* lint

* feat(ui): make chatgpt model naming consistent

* feat(ui): graph builder naming

* feat(ui): disable img2img for imagen3

* feat(ui): more naming

* feat(ui): support presigned url prefetch

* feat(ui): disable neg prompt for chatgpt

* docs(ui): update docstring

* feat(ui): fix graph building issues for chatgpt

* fix(ui): node ids for chatgpt/imagen

* chore(ui): typegen

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-04-29 09:38:03 -04:00
psychedelicious
13d44f47ce chore(ui): prettier 2025-04-29 09:12:49 +10:00
psychedelicious
550fbdeb1c fix(ui): more types fixes 2025-04-29 09:12:49 +10:00
psychedelicious
a01cd7c497 fix(ui): add chatgpt-4o to zod schemas that need to match autogenerated types 2025-04-29 09:12:49 +10:00
Mary Hipp
c54afd600c typegen 2025-04-29 09:12:49 +10:00
Mary Hipp
4f911a0ea8 typegen 2025-04-29 09:12:49 +10:00
Mary Hipp
fb91f48722 change base model for chatGPT 4o 2025-04-29 09:12:49 +10:00
psychedelicious
69db60a614 fix(ui): toast typo 2025-04-29 06:56:36 +10:00
Mary Hipp
c6d7f951aa typegen 2025-04-28 15:39:11 -04:00
Mary Hipp
04c005284c add gpt-image to possible base model types 2025-04-28 15:39:11 -04:00
psychedelicious
2d7f9697bf chore(ui): lint 2025-04-28 13:31:26 -04:00
psychedelicious
ae530492a2 chore(ui): typegen 2025-04-28 13:31:26 -04:00
psychedelicious
87ed1e3b6d feat(ui): do not allow imagen3 nodes in published workflows 2025-04-28 13:31:26 -04:00
psychedelicious
cc54466db9 fix(nodes): default value for UIConfigBase.tags 2025-04-28 13:31:26 -04:00
psychedelicious
cbdafe7e38 feat(nodes): allow node clobbering 2025-04-28 13:31:26 -04:00
psychedelicious
112cb76174 fix: random seed for edit mode imagen 2025-04-28 13:31:26 -04:00
psychedelicious
e56d41ab99 feat: rip out enhance prompt as toggleable option, imagen always randomizes seed 2025-04-28 13:31:26 -04:00
psychedelicious
273dfd86ab fix(ui): upscale builder 2025-04-28 13:31:26 -04:00
psychedelicious
871271fde5 feat(ui): rough out imagen3 support for canvas 2025-04-28 13:31:26 -04:00
psychedelicious
14944872c4 feat(mm): add model taxonomy for API models & Imagen3 as base model type 2025-04-28 13:31:26 -04:00
psychedelicious
07bcf3c446 feat(ui): port bbox select to native select 2025-04-28 13:31:26 -04:00
psychedelicious
8ed5585285 feat(nodes): move output metadata to BaseInvocationOutput 2025-04-28 09:19:43 -04:00
psychedelicious
5ce226a467 chore(ui): typegen 2025-04-28 09:19:43 -04:00
Mary Hipp
c64f20a72b remove output_metdata from schema 2025-04-28 09:19:43 -04:00
Mary Hipp
0c9c10a03a update schema 2025-04-28 09:19:43 -04:00
Mary Hipp
4a0df6b865 add optional output_metadata to baseinvocation 2025-04-28 09:19:43 -04:00
psychedelicious
ba165572bf chore: bump version to v5.11.0rc1 2025-04-28 10:10:50 +10:00
psychedelicious
c3d6a10603 fix(ui): handle minor breaking typing change from serialize-error 2025-04-28 09:53:08 +10:00
psychedelicious
4efc86299d fix(ui): type error in SettingsUpsellMenuItem 2025-04-28 09:53:08 +10:00
psychedelicious
e8c7cf63fd fix(ui): type error in canvas worker 2025-04-28 09:53:08 +10:00
psychedelicious
698b034190 chore(ui): bump deps 2025-04-28 09:53:08 +10:00
psychedelicious
3988128c40 feat(ui): add _all_ image outputs to gallery (including collections) 2025-04-28 09:49:04 +10:00
psychedelicious
c768f47365 fix(ui): dnd autoscroll in scrollable containers 2025-04-28 09:46:38 +10:00
psychedelicious
19a63abc54 fix(ui): hide file size on model picker when it is zero 2025-04-23 17:45:09 +10:00
psychedelicious
75ec36bf9a chore(ui): lint 2025-04-23 17:45:09 +10:00
psychedelicious
d802f8e7fb feat(ui): disable search when no options 2025-04-23 17:45:09 +10:00
psychedelicious
6873e0308d feat(ui): custom fallback for model picker when no models installed 2025-04-23 17:45:09 +10:00
psychedelicious
66eb73088e feat(ui): rename user-provided extra ctx for picker from ctx to extra to be less confusing 2025-04-23 17:45:09 +10:00
psychedelicious
ed81a13eb4 docs(ui): add some comments for picker 2025-04-23 17:45:09 +10:00
psychedelicious
fbc1aae52d feat(ui): more flexible fallbacks for model picker 2025-04-23 17:45:09 +10:00
psychedelicious
ba42c3e63f feat(ui): tooltip for compact/full model picker view 2025-04-23 17:45:09 +10:00
psychedelicious
b24e820aa0 fix(ui): flash of "select a model" when changing model 2025-04-23 17:45:09 +10:00
psychedelicious
e8f6b3b77a feat(ui): split out mainmodelpicker component 2025-04-23 17:45:09 +10:00
psychedelicious
8f13518c97 feat(ui): add clear search button to model combobox 2025-04-23 17:45:09 +10:00
psychedelicious
6afbc12074 feat(ui): when no model bases selected, show all models 2025-04-23 17:45:09 +10:00
psychedelicious
6b0a56ceb9 chore(ui): lint 2025-04-23 17:45:09 +10:00
psychedelicious
ca92497e52 feat(ui): remove description from model picker for now 2025-04-23 17:45:09 +10:00
psychedelicious
97d45ceaf2 feat(ui): model picker filter buttons 2025-04-23 17:45:09 +10:00
psychedelicious
aeb3841a6f feat(ui): wip model picker 2025-04-23 17:45:09 +10:00
psychedelicious
c14d33d3c1 tweak(ui): remove bg on ModelImage fallback 2025-04-23 17:45:09 +10:00
psychedelicious
676e59e072 chore(ui): bump react-resizable-panels to latest
This resolves a bug where SVG elements were ignored when checking when cursor is over a resize handle
2025-04-23 17:45:09 +10:00
psychedelicious
e7dcb6a03f feat(ui): wip model picker 2025-04-23 17:45:09 +10:00
psychedelicious
fb95b7cc2b feat(ui): wip model picker 2025-04-23 17:45:09 +10:00
psychedelicious
015dc3ac0d feat(ui): wip model picker 2025-04-23 17:45:09 +10:00
psychedelicious
9d8a71b362 feat(ui): genericizing picker 2025-04-23 17:45:09 +10:00
psychedelicious
2eb212f393 feat(ui): onSelectId -> onSelectById 2025-04-23 17:45:09 +10:00
psychedelicious
34b268c15c feat(ui): use context for stable picker state 2025-04-23 17:45:09 +10:00
psychedelicious
9a203a64dc feat(ui): render picker in portal 2025-04-23 17:45:09 +10:00
psychedelicious
d80004e056 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
de32ed23a7 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
5aed2b315d feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
48db6cfc4f feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
aa7c5c281a feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
87aeb7f889 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
3b3d6e413a feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
b6432f2de3 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
9d0a28ccae feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
c3bf0a3277 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
b516610c1e feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
677e717cd7 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
c52584e057 feat(ui): simplify ScrollableContent 2025-04-23 17:45:09 +10:00
psychedelicious
b6767441db feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
8745dbe67d feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
a565d9473e feat(ui): add useStateImperative 2025-04-23 17:45:09 +10:00
psychedelicious
4dbf07c3e0 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
f6eb4d9a6b feat(ui): toast on select for demo purposes 2025-04-23 17:45:09 +10:00
psychedelicious
5037967b82 feat(ui): just make the damn thing myself 2025-04-23 17:45:09 +10:00
psychedelicious
4930ba48ce feat(ui): just make the damn thing myself 2025-04-23 17:45:09 +10:00
psychedelicious
40d2092256 feat(ui): reworked model selection ui (WIP) 2025-04-23 17:45:09 +10:00
psychedelicious
d2e9237740 feat(ui): reworked model selection ui (WIP) 2025-04-23 17:45:09 +10:00
psychedelicious
b191b706c1 feat(ui): reworked model selection ui (WIP) 2025-04-23 17:45:09 +10:00
psychedelicious
4d0f760ec8 chore(ui): bump cmdk to latest 2025-04-23 17:45:09 +10:00
psychedelicious
65cda5365a feat(ui): remove go to mm button from node fields 2025-04-23 17:45:09 +10:00
psychedelicious
1f2d1d086f feat(ui): add <NavigateToModelManagerButton /> to model comboboxes everywhere 2025-04-23 17:45:09 +10:00
psychedelicious
418f3c3f19 feat(ui): abstract out workflow editor model combobox, ensure consistent ui for all model fields 2025-04-23 17:45:09 +10:00
psychedelicious
72173e284c fix(ui): useModelCombobox should use null for no value instead of undefined
This fixes an issue where the refiner combobox doesn't clear itself visually when clicking the little X icon to clear the selection.
2025-04-23 17:45:09 +10:00
psychedelicious
9cc13556aa feat(ui): accept callback to override navigate to model manager functionality
If a callback is provided, `<NavigateToModelManagerButton />` will render even if `disabledTabs` includes "models", and it will run the callback instead of switching tabs within the studio.

The button's tooltip is now just "Manage Models" and its icon is the same as the model manager tab's icon ([CUBE!](https://www.youtube.com/watch?v=4aGDCE6Nrz0)).
2025-04-23 17:45:09 +10:00
psychedelicious
298444f2bc chore: bump version to v5.10.1 2025-04-19 00:05:02 +10:00
psychedelicious
deb1984289 fix(mm): disable new model probe API
There is a subtle change in behaviour with the new model probe API.

Previously, checks for model types were run in a specific order. For example, we ran all main model checks before LoRA checks.

With the new API, the order of checks has changed. Check ordering is as follows:
- New API checks are run first, then legacy API checks.
- New API checks are categorized by their speed. When we run new API checks, we sort them from fastest to slowest and run them in that order. This is a performance optimization.

Currently, LoRA and LLaVA models are the only model types with the new API. Checks for them are thus run first.

LoRA checks involve checking the state dict for presence of keys with specific prefixes. We expect these keys to only exist in LoRAs.

It turns out that main models may have some of these keys.

For example, this model has keys that match the LoRA prefix `lora_te_`: https://civitai.com/models/134442/helloyoung25d

Under the old probe, we'd do the main model checks first and correctly identify this as a main model. But with the new setup, we do the LoRA check first, and those pass. So we import this model as a LoRA.

Thankfully, the old probe still exists. For now, the new probe is fully disabled. It was only called in one spot.

I've also added the example affected model as a test case for the model probe. Right now, this causes the test to fail, and I've marked the test as xfail. CI will pass.

Once we enable the new API again, the xfail will pass, and CI will fail, and we'll be reminded to update the test.
2025-04-18 22:44:10 +10:00
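The ordering hazard described in this commit message can be demonstrated in a few lines. The sketch below is illustrative only — the function names and key heuristics are hypothetical, not InvokeAI's actual probe API:

```python
# Why probe order matters: a main model that also ships stray LoRA-style keys
# is misclassified if the LoRA heuristic runs first.
def looks_like_lora(state_dict: dict) -> bool:
    # LoRA heuristic: any key with a known LoRA prefix.
    return any(k.startswith(("lora_te_", "lora_unet_")) for k in state_dict)

def looks_like_main(state_dict: dict) -> bool:
    # Main-model heuristic: presence of a UNet diffusion key.
    return any(k.startswith("model.diffusion_model.") for k in state_dict)

def classify(state_dict: dict, checks) -> str:
    # Return the first check that matches, like the probe does.
    for name, check in checks:
        if check(state_dict):
            return name
    return "unknown"

# A main model that *also* contains keys matching the LoRA prefix `lora_te_`,
# like the example model linked above.
sd = {
    "model.diffusion_model.input_blocks.0.weight": None,
    "lora_te_text_model_encoder.weight": None,
}

legacy_order = [("main", looks_like_main), ("lora", looks_like_lora)]
new_order = [("lora", looks_like_lora), ("main", looks_like_main)]

print(classify(sd, legacy_order))  # main (correct)
print(classify(sd, new_order))     # lora (misclassified)
```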
psychedelicious
814406d98a feat(mm): siglip model loading supports partial loading
In the previous commit, the LLaVA model was updated to support partial loading.

In this commit, the SigLIP model is updated in the same way.

This model is used for FLUX Redux. It's <4GB and only ever run in isolation, so it won't benefit from partial loading for the vast majority of users. Regardless, I think it is best if we make _all_ models work with partial loading.

PS: I also fixed the initial load dtype issue, described in the prev commit. It's probably a non-issue for this model, but we may as well fix it.
2025-04-18 10:12:03 +10:00
psychedelicious
c054501103 feat(mm): llava model loading supports partial loading; fix OOM crash on initial load
The model manager has two types of model cache entries:
- `CachedModelOnlyFullLoad`: The model may only ever be loaded and unloaded as a single object.
- `CachedModelWithPartialLoad`: The model may be partially loaded and unloaded.

Partial loading is enabled by overriding certain torch layer classes, adding the ability to autocast the layer to a device on-the-fly. See `CustomLinear` for an example.

So, to take advantage of partial loading and be cached as a `CachedModelWithPartialLoad`, the model must inherit from `torch.nn.Module`.

The LLaVA classes provided by `transformers` do inherit from `torch.nn.Module`, but we wrap those classes in a separate class called `LlavaOnevisionModel`. The wrapper encapsulates both the LLaVA model and its "processor" - a lightweight class that prepares model inputs like text and images.

While it is more elegant to encapsulate both model and processor classes in a single entity, this prevents the model cache from enabling partial loading for the chunky vLLM model.

Fixing this involved a few changes.
- Update the `LlavaOnevisionModelLoader` class to operate on the vLLM model directly, instead of the `LlavaOnevisionModel` wrapper class.
- Instantiate the processor directly in the node. The processor is lightweight and does its business on the CPU. We don't need to worry about caching in the model manager.
- Remove caching support code from the `LlavaOnevisionModel` wrapper class. It's not needed, because we do not cache this class. The class now only handles running the models provided to it.
- Rename `LlavaOnevisionModel` to `LlavaOnevisionPipeline` to better represent its purpose.

These changes have a bonus effect of fixing an OOM crash when initially loading the models. This was most apparent when loading LLaVA 7B, which is pretty chunky.

The initial load is onto CPU RAM. In the old version of the loaders, we ignored the loader's target dtype for the initial load. Instead, we loaded the model at `transformers`'s "default" dtype of fp32.

LLaVA 7B is fp16 and weighs ~17GB. Loading as fp32 means we need double that amount (~34GB) of CPU RAM. Many users only have 32GB RAM, so this causes a _CPU_ OOM - which is a hard crash of the whole process.

With the updated loaders, the initial load logic now uses the target dtype for the initial load. LLaVA now needs the expected ~17GB RAM for its initial load.

PS: If we didn't make the accompanying partial loading changes, we still could have solved this OOM. We'd just need to pass the initial load dtype to the wrapper class and have it load on that dtype. But we may as well fix both issues.

PPS: There are other models whose model classes are wrappers around a torch module class, and thus cannot be partially loaded. However, these models are typically fairly small and/or are run only on their own, so they don't benefit as much from partial loading. It's the really big models (like LLaVA 7B) that benefit most from the partial loading.
2025-04-18 10:12:03 +10:00
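The RAM arithmetic behind that CPU OOM (an fp32 initial load doubling the footprint of an fp16 checkpoint) can be sanity-checked in a few lines; the parameter count below is a rough illustrative guess, not the real model's:

```python
def load_ram_gb(n_params: float, bytes_per_param: int) -> float:
    """CPU RAM needed just to materialize a checkpoint's weights, in GiB."""
    return n_params * bytes_per_param / 1024**3

n = 8.5e9  # rough, illustrative parameter count for a ~17 GB fp16 checkpoint
fp16_gb = load_ram_gb(n, 2)  # two bytes per weight at fp16
fp32_gb = load_ram_gb(n, 4)  # four bytes per weight at fp32

# Loading at fp32 needs exactly double the fp16 footprint - enough to
# push a 32 GB machine into a hard CPU OOM.
print(f"fp16: {fp16_gb:.1f} GiB, fp32: {fp32_gb:.1f} GiB")
```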
psychedelicious
c1d819c7e5 feat(nodes): add get_absolute_path method to context.models API
Given a model config or path (presumably to a model), returns the absolute path to the model.

Check the next few commits for use-case.
2025-04-18 10:12:03 +10:00
psychedelicious
2a8e91f94d feat(ui): wrap JSON in dataviewer 2025-04-17 22:55:04 +10:00
174 changed files with 5521 additions and 3492 deletions

View File

@@ -39,7 +39,7 @@ nodes imported in the `__init__.py` file are loaded. See the README in the nodes
folder for more examples:
```py
-from .cool_node import CoolInvocation
+from .cool_node import ResizeInvocation
```
## Creating A New Invocation
@@ -69,7 +69,10 @@ The first set of things we need to do when creating a new Invocation are -
So let us do that.
```python
-from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.invocation_api import (
+    BaseInvocation,
+    invocation,
+)
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -103,8 +106,12 @@ create your own custom field types later in this guide. For now, let's go ahead
and use it.
```python
-from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
-from invokeai.app.invocations.primitives import ImageField
+from invokeai.invocation_api import (
+    BaseInvocation,
+    ImageField,
+    InputField,
+    invocation,
+)
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -128,8 +135,12 @@ image: ImageField = InputField(description="The input image")
Great. Now let us create our other inputs for `width` and `height`
```python
-from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
-from invokeai.app.invocations.primitives import ImageField
+from invokeai.invocation_api import (
+    BaseInvocation,
+    ImageField,
+    InputField,
+    invocation,
+)
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -163,8 +174,13 @@ that are provided by it by InvokeAI.
Let us create this function first.
```python
-from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
-from invokeai.app.invocations.primitives import ImageField
+from invokeai.invocation_api import (
+    BaseInvocation,
+    ImageField,
+    InputField,
+    InvocationContext,
+    invocation,
+)
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -191,8 +207,14 @@ all the necessary info related to image outputs. So let us use that.
We will cover how to create your own output types later in this guide.
```python
-from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
-from invokeai.app.invocations.primitives import ImageField
+from invokeai.invocation_api import (
+    BaseInvocation,
+    ImageField,
+    InputField,
+    InvocationContext,
+    invocation,
+)
from invokeai.app.invocations.image import ImageOutput
@invocation('resize')
@@ -217,9 +239,15 @@ Perfect. Now that we have our Invocation setup, let us do what we want to do.
So let's do that.
```python
-from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
-from invokeai.app.invocations.primitives import ImageField
-from invokeai.app.invocations.image import ImageOutput, ResourceOrigin, ImageCategory
+from invokeai.invocation_api import (
+    BaseInvocation,
+    ImageField,
+    InputField,
+    InvocationContext,
+    invocation,
+)
+from invokeai.app.invocations.image import ImageOutput
@invocation("resize")
class ResizeInvocation(BaseInvocation):

View File

@@ -893,6 +893,12 @@ class HFTokenHelper:
        huggingface_hub.login(token=token, add_to_git_credential=False)
        return cls.get_status()

+    @classmethod
+    def reset_token(cls) -> HFTokenStatus:
+        with SuppressOutput(), contextlib.suppress(Exception):
+            huggingface_hub.logout()
+        return cls.get_status()
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
async def get_hf_login_status() -> HFTokenStatus:
@@ -915,3 +921,8 @@ async def do_hf_login(
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
return token_status
+@model_manager_router.delete("/hf_login", operation_id="reset_hf_token", response_model=HFTokenStatus)
+async def reset_hf_token() -> HFTokenStatus:
+    return HFTokenHelper.reset_token()
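The reset route above leans on `contextlib.suppress` so that a logout failure (e.g. no token saved) still yields a clean status response. A minimal stdlib sketch of that pattern, with a stand-in for `huggingface_hub.logout` — hypothetical plumbing, not the real route:

```python
import contextlib

def reset_token(logout) -> str:
    """Sketch of the pattern used by the route: try to log out, swallow any
    failure, and always return a fresh status. `logout` stands in for
    huggingface_hub.logout."""
    with contextlib.suppress(Exception):
        logout()
    return "no_token"  # stand-in for cls.get_status()

def failing_logout():
    # Simulate huggingface_hub raising because no token is stored.
    raise RuntimeError("no token stored")

print(reset_token(failing_logout))  # the error is suppressed; we still get a status
```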

View File

@@ -25,7 +25,7 @@ from typing import (
)
import semver
-from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
+from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
@@ -72,13 +72,24 @@ class Classification(str, Enum, metaclass=MetaEnum):
Special = "special"
+class Bottleneck(str, Enum, metaclass=MetaEnum):
+    """
+    The bottleneck of an invocation.
+    - `Network`: The invocation's execution is network-bound.
+    - `GPU`: The invocation's execution is GPU-bound.
+    """
+    Network = "network"
+    GPU = "gpu"
class UIConfigBase(BaseModel):
"""
Provides additional node configuration to the UI.
This is used internally by the @invocation decorator logic. Do not use this directly.
"""
-    tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
+    tags: Optional[list[str]] = Field(default=None, description="The node's tags")
title: Optional[str] = Field(default=None, description="The node's display name")
category: Optional[str] = Field(default=None, description="The node's category")
version: str = Field(
@@ -100,6 +111,12 @@ class BaseInvocationOutput(BaseModel):
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
"""
+    output_meta: Optional[dict[str, JsonValue]] = Field(
+        default=None,
+        description="Optional dictionary of metadata for the invocation output, unrelated to the invocation's actual output value. This is not exposed as an output field.",
+        json_schema_extra={"field_kind": FieldKind.NodeAttribute},
+    )
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
@@ -235,6 +252,8 @@ class BaseInvocation(ABC, BaseModel):
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
+    bottleneck: ClassVar[Bottleneck]
UIConfig: ClassVar[UIConfigBase]
model_config = ConfigDict(
@@ -256,6 +275,26 @@ class InvocationRegistry:
@classmethod
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
"""Registers an invocation."""
+        invocation_type = invocation.get_type()
+        node_pack = invocation.UIConfig.node_pack
+        # Log a warning when an existing invocation is being clobbered by the one we are registering
+        clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
+        if clobbered_invocation is not None:
+            # This should always be true - we just checked if the invocation type was in the set
+            clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
+            if clobbered_node_pack == "invokeai":
+                # The invocation being clobbered is a core invocation
+                logger.warning(f'Overriding core node "{invocation_type}" with node from "{node_pack}"')
+            else:
+                # The invocation being clobbered is a custom invocation
+                logger.warning(
+                    f'Overriding node "{invocation_type}" from "{node_pack}" with node from "{clobbered_node_pack}"'
+                )
+            cls._invocation_classes.remove(clobbered_invocation)
cls._invocation_classes.add(invocation)
cls.invalidate_invocation_typeadapter()
@@ -314,6 +353,15 @@ class InvocationRegistry:
@classmethod
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
"""Registers an invocation output."""
+        output_type = output.get_type()
+        # Log a warning when an existing invocation is being clobbered by the one we are registering
+        clobbered_output = InvocationRegistry.get_output_for_type(output_type)
+        if clobbered_output is not None:
+            # TODO(psyche): We do not record the node pack of the output, so we cannot log it here
+            logger.warning(f'Overriding invocation output "{output_type}"')
+            cls._output_classes.remove(clobbered_output)
cls._output_classes.add(output)
cls.invalidate_output_typeadapter()
@@ -322,6 +370,11 @@ class InvocationRegistry:
"""Gets all invocation outputs."""
return cls._output_classes
+    @classmethod
+    def get_outputs_map(cls) -> dict[str, type[BaseInvocationOutput]]:
+        """Gets a map of all output types to their output classes."""
+        return {i.get_type(): i for i in cls.get_output_classes()}
@classmethod
@lru_cache(maxsize=1)
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
@@ -347,6 +400,11 @@ class InvocationRegistry:
"""Gets all invocation output types."""
return (i.get_type() for i in cls.get_output_classes())
+    @classmethod
+    def get_output_for_type(cls, output_type: str) -> type[BaseInvocationOutput] | None:
+        """Gets the output class for a given output type."""
+        return cls.get_outputs_map().get(output_type)
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"id",
@@ -354,11 +412,12 @@ RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"use_cache",
"type",
"workflow",
+    "bottleneck",
}
RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"}
-RESERVED_OUTPUT_FIELD_NAMES = {"type"}
+RESERVED_OUTPUT_FIELD_NAMES = {"type", "output_meta"}
class _Model(BaseModel):
@@ -438,6 +497,7 @@ def invocation(
version: Optional[str] = None,
use_cache: Optional[bool] = True,
classification: Classification = Classification.Stable,
+    bottleneck: Bottleneck = Bottleneck.GPU,
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
"""
Registers an invocation.
@@ -449,6 +509,7 @@ def invocation(
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
:param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
+    :param Bottleneck bottleneck: The bottleneck of the invocation. Defaults to Bottleneck.GPU. Use Network if the invocation is network-bound.
"""
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
@@ -460,25 +521,6 @@ def invocation(
# The node pack is the module name - will be "invokeai" for built-in nodes
node_pack = cls.__module__.split(".")[0]
-        # Handle the case where an existing node is being clobbered by the one we are registering
-        if invocation_type in InvocationRegistry.get_invocation_types():
-            clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
-            # This should always be true - we just checked if the invocation type was in the set
-            assert clobbered_invocation is not None
-            clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
-            if clobbered_node_pack == "invokeai":
-                # The node being clobbered is a core node
-                raise ValueError(
-                    f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a core node with the same type already exists'
-                )
-            else:
-                # The node being clobbered is a custom node
-                raise ValueError(
-                    f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a node with the same type already exists in node pack "{clobbered_node_pack}"'
-                )
validate_fields(cls.model_fields, invocation_type)
# Add OpenAPI schema extras
@@ -504,6 +546,8 @@ def invocation(
if use_cache is not None:
cls.model_fields["use_cache"].default = use_cache
cls.bottleneck = bottleneck
# Add the invocation type to the model.
# You'd be tempted to just add the type field and rebuild the model, like this:
@@ -572,13 +616,9 @@ def invocation_output(
if re.compile(r"^\S+$").match(output_type) is None:
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
if output_type in InvocationRegistry.get_output_types():
raise ValueError(f'Invocation type "{output_type}" already exists')
validate_fields(cls.model_fields, output_type)
# Add the output type to the model.
output_type_annotation = Literal[output_type] # type: ignore
output_type_field = Field(
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}

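The `bottleneck` flag above is recorded on the invocation class by the decorator. A minimal self-contained sketch of the mechanism, using a stand-in `Bottleneck` enum and a simplified decorator rather than the real InvokeAI API:

```python
from enum import Enum
from typing import Callable


class Bottleneck(str, Enum):
    GPU = "gpu"
    Network = "network"


def invocation(invocation_type: str, bottleneck: Bottleneck = Bottleneck.GPU) -> Callable[[type], type]:
    """Simplified stand-in for the real decorator: records metadata on the class."""

    def wrapper(cls: type) -> type:
        cls.invocation_type = invocation_type
        cls.bottleneck = bottleneck
        return cls

    return wrapper


# Hypothetical network-bound node, e.g. one that calls a remote image API.
@invocation("chatgpt_4o_generate", bottleneck=Bottleneck.Network)
class ChatGPT4oNode:
    pass
```

A scheduler can then inspect `cls.bottleneck` to decide whether a node should occupy the GPU or can run concurrently with GPU work.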

@@ -61,6 +61,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
SigLipModel = "SigLipModelField"
FluxReduxModel = "FluxReduxModelField"
LlavaOnevisionModel = "LLaVAModelField"
Imagen3Model = "Imagen3ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
# endregion
# region Misc Field Types


@@ -3,6 +3,7 @@ from typing import Literal, Optional
import torch
from PIL import Image
from transformers import SiglipImageProcessor, SiglipVisionModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@@ -115,8 +116,14 @@ class FluxReduxInvocation(BaseInvocation):
@torch.no_grad()
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
siglip_model_config = self._get_siglip_model(context)
with context.models.load(siglip_model_config.key).model_on_device() as (_, siglip_pipeline):
assert isinstance(siglip_pipeline, SigLipPipeline)
with context.models.load(siglip_model_config.key).model_on_device() as (_, model):
assert isinstance(model, SiglipVisionModel)
model_abs_path = context.models.get_absolute_path(siglip_model_config)
processor = SiglipImageProcessor.from_pretrained(model_abs_path, local_files_only=True)
assert isinstance(processor, SiglipImageProcessor)
siglip_pipeline = SigLipPipeline(processor, model)
return siglip_pipeline.encode_image(
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)


@@ -3,13 +3,14 @@ from typing import Any
import torch
from PIL.Image import Image
from pydantic import field_validator
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline
from invokeai.backend.util.devices import TorchDevice
@@ -54,10 +55,17 @@ class LlavaOnevisionVllmInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> StringOutput:
images = self._get_images(context)
model_config = context.models.get_config(self.vllm_model)
with context.models.load(self.vllm_model) as vllm_model:
assert isinstance(vllm_model, LlavaOnevisionModel)
output = vllm_model.run(
with context.models.load(self.vllm_model).model_on_device() as (_, model):
assert isinstance(model, LlavaOnevisionForConditionalGeneration)
model_abs_path = context.models.get_absolute_path(model_config)
processor = AutoProcessor.from_pretrained(model_abs_path, local_files_only=True)
assert isinstance(processor, LlavaOnevisionProcessor)
model = LlavaOnevisionPipeline(model, processor)
output = model.run(
prompt=self.prompt,
images=images,
device=TorchDevice.choose_torch_device(),


@@ -241,6 +241,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
@classmethod
def build(
@@ -263,6 +264,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
credits=queue_item.credits,
)


@@ -38,7 +38,6 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
InvalidModelConfigException,
ModelConfigBase,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.metadata import (
@@ -647,10 +646,14 @@ class ModelInstallService(ModelInstallServiceBase):
hash_algo = self._app_config.hashing_algorithm
fields = config.model_dump()
try:
return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
except InvalidModelConfigException:
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)
# New model probe API is disabled pending resolution of issue caused by a change of the ordering of checks.
# See commit message for details.
# try:
# return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
# except InvalidModelConfigException:
# return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None


@@ -257,6 +257,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
api_output_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The nodes that were used as output from the API"
)
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
@classmethod
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":


@@ -21,6 +21,7 @@ from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.step_callback import diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ModelConfigBase,
)
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
@@ -543,6 +544,30 @@ class ModelsInterface(InvocationContextInterface):
self._util.signal_progress(f"Loading model {source}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
def get_absolute_path(self, config_or_path: AnyModelConfig | Path | str) -> Path:
"""Gets the absolute path for a given model config or path.
For example, if the model's path is `flux/main/FLUX Dev.safetensors`, and the models path is
`/home/username/InvokeAI/models`, this method will return
`/home/username/InvokeAI/models/flux/main/FLUX Dev.safetensors`.
Args:
config_or_path: The model config or path.
Returns:
The absolute path to the model.
"""
model_path = Path(config_or_path.path) if isinstance(config_or_path, ModelConfigBase) else Path(config_or_path)
if model_path.is_absolute():
return model_path.resolve()
base_models_path = self._services.configuration.models_path
joined_path = base_models_path / model_path
resolved_path = joined_path.resolve()
return resolved_path
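In isolation, the resolution logic of `get_absolute_path` behaves as below. This is a self-contained sketch using plain `pathlib` with the models path passed in explicitly (the real method reads it from the app configuration):

```python
from __future__ import annotations

from pathlib import Path


def get_absolute_path(models_path: Path, model_path: Path | str) -> Path:
    """Resolve a model path that may be relative to the models directory."""
    p = Path(model_path)
    if p.is_absolute():
        # Already absolute: just normalize it.
        return p.resolve()
    # Relative: join onto the configured models directory, then normalize.
    return (models_path / p).resolve()
```

For example, `get_absolute_path(Path("/home/username/InvokeAI/models"), "flux/main/FLUX Dev.safetensors")` yields the absolute path under the models directory, while an already-absolute input is returned unchanged.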
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:


@@ -61,6 +61,10 @@ def get_openapi_func(
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in InvocationRegistry.get_output_classes():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
# Remove output_metadata that is only used on back-end from the schema
if "output_meta" in json_schema["properties"]:
json_schema["properties"].pop("output_meta")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
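Stripping the backend-only `output_meta` property is a guarded `dict.pop` on the generated schema. A minimal sketch with a hypothetical output schema:

```python
# Hypothetical JSON schema for an invocation output, as pydantic might emit it.
json_schema = {
    "title": "ImageOutput",
    "properties": {
        "image": {"type": "object"},
        "output_meta": {"type": "object"},  # backend-only; should not reach clients
    },
}

# Remove the backend-only property before publishing the schema.
if "output_meta" in json_schema["properties"]:
    json_schema["properties"].pop("output_meta")
```

The membership check keeps the operation safe for output classes that never declared the field.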


@@ -10,7 +10,7 @@ def get_timestamp() -> int:
def get_iso_timestamp() -> str:
return datetime.datetime.utcnow().isoformat()
return datetime.datetime.now(datetime.timezone.utc).isoformat()
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:

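The timestamp change above swaps the deprecated `datetime.datetime.utcnow()` (which returns a naive datetime with no `tzinfo`) for a timezone-aware equivalent whose ISO string carries an explicit UTC offset:

```python
import datetime

# Deprecated pattern: a naive datetime; isoformat() has no UTC offset.
naive = datetime.datetime.utcnow()

# Replacement: timezone-aware; isoformat() ends with "+00:00".
aware = datetime.datetime.now(datetime.timezone.utc)
```

Aware timestamps round-trip unambiguously through `datetime.datetime.fromisoformat`, which is what `get_datetime_from_iso_timestamp` relies on.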

@@ -1,26 +1,15 @@
from pathlib import Path
from typing import Optional
import torch
from PIL.Image import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
from invokeai.backend.raw_model import RawModel
from transformers import LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
class LlavaOnevisionModel(RawModel):
class LlavaOnevisionPipeline:
"""A wrapper for a LLaVA Onevision model + processor."""
def __init__(self, vllm_model: LlavaOnevisionForConditionalGeneration, processor: LlavaOnevisionProcessor):
self._vllm_model = vllm_model
self._processor = processor
@classmethod
def load_from_path(cls, path: str | Path):
vllm_model = LlavaOnevisionForConditionalGeneration.from_pretrained(path, local_files_only=True)
assert isinstance(vllm_model, LlavaOnevisionForConditionalGeneration)
processor = AutoProcessor.from_pretrained(path, local_files_only=True)
assert isinstance(processor, LlavaOnevisionProcessor)
return cls(vllm_model, processor)
def run(self, prompt: str, images: list[Image], device: torch.device, dtype: torch.dtype) -> str:
# TODO(ryand): Tune the max number of images that are useful for the model.
if len(images) > 3:
@@ -44,13 +33,3 @@ class LlavaOnevisionModel(RawModel):
# The output_str will include the prompt, so we extract the response.
response = output_str.split("assistant\n", 1)[1].strip()
return response
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
self._vllm_model.to(device=device, dtype=dtype)
def calc_size(self) -> int:
"""Get size of the model in memory in bytes."""
# HACK(ryand): Fix this issue with circular imports.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._vllm_model)


@@ -144,6 +144,7 @@ class ModelConfigBase(ABC, BaseModel):
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
description="Loadable submodels in this model", default=None
)
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
_USING_LEGACY_PROBE: ClassVar[set] = set()
_USING_CLASSIFY_API: ClassVar[set] = set()
@@ -600,6 +601,21 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
}
class ApiModelConfig(MainConfigBase, ModelConfigBase):
"""Model config for API-based models."""
format: Literal[ModelFormat.Api] = ModelFormat.Api
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
# API models are not stored on disk, so we can't match them.
return False
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
raise NotImplementedError("API models are not parsed from disk.")
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
@@ -667,6 +683,7 @@ AnyModelConfig = Annotated[
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
Annotated[ApiModelConfig, ApiModelConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]


@@ -13,6 +13,12 @@ from invokeai.backend.patches.layers.lora_layer import LoRALayer
def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
# up matrix and down matrix have different ranks so we can't simply multiply them
if lora_layer.up.shape[1] != lora_layer.down.shape[0]:
x = torch.nn.functional.linear(input, lora_layer.get_weight(lora_weight), bias=lora_layer.bias)
x *= lora_weight * lora_layer.scale()
return x
x = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x = torch.nn.functional.linear(x, lora_layer.mid)


@@ -1,7 +1,8 @@
from pathlib import Path
from typing import Optional
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from transformers import LlavaOnevisionForConditionalGeneration
from invokeai.backend.model_manager.config import (
AnyModelConfig,
)
@@ -23,6 +24,8 @@ class LlavaOnevisionModelLoader(ModelLoader):
raise ValueError("Unexpected submodel requested for LLaVA OneVision model.")
model_path = Path(config.path)
model = LlavaOnevisionModel.load_from_path(model_path)
model.to(dtype=self._torch_dtype)
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
model_path, local_files_only=True, torch_dtype=self._torch_dtype
)
assert isinstance(model, LlavaOnevisionForConditionalGeneration)
return model


@@ -1,13 +1,14 @@
from pathlib import Path
from typing import Optional
from transformers import SiglipVisionModel
from invokeai.backend.model_manager.config import (
AnyModelConfig,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.SigLIP, format=ModelFormat.Diffusers)
@@ -23,6 +24,5 @@ class SigLIPModelLoader(ModelLoader):
raise ValueError("Unexpected submodel requested for SigLIP model.")
model_path = Path(config.path)
model = SigLipPipeline.load_from_path(model_path)
model.to(dtype=self._torch_dtype)
model = SiglipVisionModel.from_pretrained(model_path, local_files_only=True, torch_dtype=self._torch_dtype)
return model


@@ -16,11 +16,9 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.model_manager.taxonomy import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
@@ -51,8 +49,6 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
GroundingDinoPipeline,
SegmentAnythingPipeline,
DepthAnythingPipeline,
SigLipPipeline,
LlavaOnevisionModel,
),
):
return model.calc_size()


@@ -26,7 +26,8 @@ class BaseModelType(str, Enum):
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
CogView4 = "cogview4"
# Kandinsky2_1 = "kandinsky-2.1"
Imagen3 = "imagen3"
ChatGPT4o = "chatgpt-4o"
class ModelType(str, Enum):
@@ -98,6 +99,7 @@ class ModelFormat(str, Enum):
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
GGUFQuantized = "gguf_quantized"
Api = "api"
class SchedulerPredictionType(str, Enum):


@@ -19,6 +19,7 @@ class LoRALayer(LoRALayerBase):
self.up = up
self.mid = mid
self.down = down
self.are_ranks_equal = up.shape[1] == down.shape[0]
@classmethod
def from_state_dict_values(
@@ -58,12 +59,42 @@ class LoRALayer(LoRALayerBase):
def _rank(self) -> int:
return self.down.shape[0]
def fuse_weights(self, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
"""
Fuse the weights of the up and down matrices of a LoRA layer with different ranks.
Since the Hugging Face implementation fuses the KQV projections, the LoRA weights have different
ranks when converted to Kohya format. This function handles fusing these differently sized
matrices.
"""
fused_lora = torch.zeros((up.shape[0], down.shape[1]), device=down.device, dtype=down.dtype)
rank_diff = down.shape[0] / up.shape[1]
if rank_diff > 1:
rank_diff = down.shape[0] / up.shape[1]
w_down = down.chunk(int(rank_diff), dim=0)
for w_down_chunk in w_down:
fused_lora = fused_lora + (torch.mm(up, w_down_chunk))
else:
rank_diff = up.shape[1] / down.shape[0]
w_up = up.chunk(int(rank_diff), dim=0)
for w_up_chunk in w_up:
fused_lora = fused_lora + (torch.mm(w_up_chunk, down))
return fused_lora
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
# up matrix and down matrix have different ranks so we can't simply multiply them
if not self.are_ranks_equal:
weight = self.fuse_weights(self.up, self.down)
return weight
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
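The chunk-and-accumulate fusion in `fuse_weights` can be checked numerically. Below is a self-contained sketch using NumPy instead of torch, covering the common case where the down matrix's rank is an integer multiple of the up matrix's rank; the shapes are hypothetical small examples:

```python
import numpy as np


def fuse_weights(up: np.ndarray, down: np.ndarray) -> np.ndarray:
    """Fuse LoRA up/down matrices whose inner dims differ.

    Mirrors the chunk-and-accumulate idea of the torch code above for the
    case where down.shape[0] is an integer multiple of up.shape[1].
    """
    assert down.shape[0] % up.shape[1] == 0
    fused = np.zeros((up.shape[0], down.shape[1]), dtype=down.dtype)
    for chunk in np.split(down, down.shape[0] // up.shape[1], axis=0):
        # Each chunk has shape (up_rank, in_features), so up @ chunk is valid.
        fused += up @ chunk
    return fused


# Hypothetical shapes: up is (out=4, rank=2), down is (2*rank=4, in=3).
up = np.arange(8, dtype=np.float64).reshape(4, 2)
down = np.arange(12, dtype=np.float64).reshape(4, 3)
fused = fuse_weights(up, down)
```

The result equals the sum of `up` multiplied against each rank-sized slice of `down`, which is what the loop in the torch implementation accumulates.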


@@ -20,6 +20,14 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
# A regex pattern that matches all of the last layer keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_unet_final_layer_linear.alpha
# lora_unet_final_layer_linear.lora_down.weight
# lora_unet_final_layer_linear.lora_up.weight
FLUX_KOHYA_LAST_LAYER_KEY_REGEX = r"lora_unet_final_layer_(linear|linear1|linear2)_?(.*)"
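The new pattern can be sanity-checked against the example keys listed in the comment above:

```python
import re

FLUX_KOHYA_LAST_LAYER_KEY_REGEX = r"lora_unet_final_layer_(linear|linear1|linear2)_?(.*)"

# Example keys from the Kohya FLUX LoRA format, as documented above.
keys = [
    "lora_unet_final_layer_linear.alpha",
    "lora_unet_final_layer_linear.lora_down.weight",
    "lora_unet_final_layer_linear.lora_up.weight",
]
matches = [re.match(FLUX_KOHYA_LAST_LAYER_KEY_REGEX, k) for k in keys]
```

Note that `re.match` anchors at the start of the string only, so the trailing `(.*)` group absorbs the `.alpha` / `.lora_down.weight` suffixes.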
# A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha
@@ -44,6 +52,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
"""
return all(
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
or re.match(FLUX_KOHYA_LAST_LAYER_KEY_REGEX, k)
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
@@ -65,6 +74,9 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
for layer_name, layer_state_dict in grouped_state_dict.items():
if layer_name.startswith("lora_unet"):
# Skip the final layer. This is incompatible with current model definition.
if layer_name.startswith("lora_unet_final_layer"):
continue
transformer_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te1"):
clip_grouped_sd[layer_name] = layer_state_dict


@@ -1,14 +1,9 @@
from pathlib import Path
from typing import Optional
import torch
from PIL import Image
from transformers import SiglipImageProcessor, SiglipVisionModel
from invokeai.backend.raw_model import RawModel
class SigLipPipeline(RawModel):
class SigLipPipeline:
"""A wrapper for a SigLIP model + processor."""
def __init__(
@@ -19,25 +14,7 @@ class SigLipPipeline(RawModel):
self._siglip_processor = siglip_processor
self._siglip_model = siglip_model
@classmethod
def load_from_path(cls, path: str | Path):
siglip_model = SiglipVisionModel.from_pretrained(path, local_files_only=True)
assert isinstance(siglip_model, SiglipVisionModel)
siglip_processor = SiglipImageProcessor.from_pretrained(path, local_files_only=True)
assert isinstance(siglip_processor, SiglipImageProcessor)
return cls(siglip_processor, siglip_model)
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
self._siglip_model.to(device=device, dtype=dtype)
def encode_image(self, x: Image.Image, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
imgs = self._siglip_processor.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True)
encoded_x = self._siglip_model(**imgs.to(device=device, dtype=dtype)).last_hidden_state
return encoded_x
def calc_size(self) -> int:
"""Get size of the model in memory in bytes."""
# HACK(ryand): Fix this issue with circular imports.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._siglip_model)


@@ -52,68 +52,68 @@
}
},
"dependencies": {
"@atlaskit/pragmatic-drag-and-drop": "^1.4.0",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^1.4.0",
"@atlaskit/pragmatic-drag-and-drop": "^1.5.3",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^2.1.0",
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.0.3",
"@dagrejs/dagre": "^1.1.4",
"@dagrejs/graphlib": "^2.2.4",
"@fontsource-variable/inter": "^5.1.0",
"@fontsource-variable/inter": "^5.2.5",
"@invoke-ai/ui-library": "^0.0.46",
"@nanostores/react": "^0.7.3",
"@reduxjs/toolkit": "2.6.1",
"@nanostores/react": "^1.0.0",
"@reduxjs/toolkit": "2.7.0",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.5.3",
"@xyflow/react": "^12.6.0",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.2",
"cmdk": "^1.0.0",
"cmdk": "^1.1.1",
"compare-versions": "^6.1.1",
"filesize": "^10.1.6",
"fracturedjsonjs": "^4.0.2",
"framer-motion": "^11.10.0",
"i18next": "^23.15.1",
"i18next-http-backend": "^2.6.1",
"i18next": "^25.0.1",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "^6.2.1",
"jsondiffpatch": "^0.6.0",
"konva": "^9.3.15",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.20",
"linkify-react": "^4.2.0",
"linkifyjs": "^4.2.0",
"lodash-es": "^4.17.21",
"lru-cache": "^11.0.1",
"lru-cache": "^11.1.0",
"mtwist": "^1.0.2",
"nanoid": "^5.0.7",
"nanostores": "^0.11.3",
"new-github-issue-url": "^1.0.0",
"overlayscrollbars": "^2.10.0",
"nanoid": "^5.1.5",
"nanostores": "^1.0.1",
"new-github-issue-url": "^1.1.0",
"overlayscrollbars": "^2.11.1",
"overlayscrollbars-react": "^0.5.6",
"perfect-freehand": "^1.2.2",
"query-string": "^9.1.0",
"query-string": "^9.1.1",
"raf-throttle": "^2.0.6",
"react": "^18.3.1",
"react-colorful": "^5.6.1",
"react-dom": "^18.3.1",
"react-dropzone": "^14.2.9",
"react-error-boundary": "^4.0.13",
"react-hook-form": "^7.53.0",
"react-dropzone": "^14.3.8",
"react-error-boundary": "^5.0.0",
"react-hook-form": "^7.56.1",
"react-hotkeys-hook": "4.5.0",
"react-i18next": "^15.0.2",
"react-icons": "^5.3.0",
"react-redux": "9.1.2",
"react-resizable-panels": "^2.1.4",
"react-textarea-autosize": "^8.5.7",
"react-use": "^17.5.1",
"react-virtuoso": "^4.12.5",
"react-i18next": "^15.5.1",
"react-icons": "^5.5.0",
"react-redux": "9.2.0",
"react-resizable-panels": "^2.1.8",
"react-textarea-autosize": "^8.5.9",
"react-use": "^17.6.0",
"react-virtuoso": "^4.12.6",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.1.0",
"redux-remember": "^5.2.0",
"redux-undo": "^1.1.0",
"rfdc": "^1.4.1",
"roarr": "^7.21.1",
"serialize-error": "^11.0.3",
"socket.io-client": "^4.8.0",
"stable-hash": "^0.0.4",
"use-debounce": "^10.0.3",
"serialize-error": "^12.0.0",
"socket.io-client": "^4.8.1",
"stable-hash": "^0.0.5",
"use-debounce": "^10.0.4",
"use-device-pixel-ratio": "^1.1.2",
"uuid": "^10.0.0",
"zod": "^3.23.8",
"uuid": "^11.1.0",
"zod": "^3.24.3",
"zod-validation-error": "^3.4.0"
},
"peerDependencies": {
@@ -123,43 +123,43 @@
"devDependencies": {
"@invoke-ai/eslint-config-react": "^0.0.14",
"@invoke-ai/prettier-config-react": "^0.0.7",
"@storybook/addon-essentials": "^8.3.4",
"@storybook/addon-interactions": "^8.3.4",
"@storybook/addon-links": "^8.3.4",
"@storybook/addon-storysource": "^8.3.4",
"@storybook/manager-api": "^8.3.4",
"@storybook/react": "^8.3.4",
"@storybook/react-vite": "^8.5.5",
"@storybook/theming": "^8.3.4",
"@storybook/addon-essentials": "^8.6.12",
"@storybook/addon-interactions": "^8.6.12",
"@storybook/addon-links": "^8.6.12",
"@storybook/addon-storysource": "^8.6.12",
"@storybook/manager-api": "^8.6.12",
"@storybook/react": "^8.6.12",
"@storybook/react-vite": "^8.6.12",
"@storybook/theming": "^8.6.12",
"@types/lodash-es": "^4.17.12",
"@types/node": "^20.16.10",
"@types/node": "^22.15.1",
"@types/react": "^18.3.11",
"@types/react-dom": "^18.3.0",
"@types/uuid": "^10.0.0",
"@vitejs/plugin-react-swc": "^3.8.0",
"@vitest/coverage-v8": "^3.0.6",
"@vitest/ui": "^3.0.6",
"concurrently": "^8.2.2",
"@vitejs/plugin-react-swc": "^3.9.0",
"@vitest/coverage-v8": "^3.1.2",
"@vitest/ui": "^3.1.2",
"concurrently": "^9.1.2",
"csstype": "^3.1.3",
"dpdm": "^3.14.0",
"eslint": "^8.57.1",
"eslint-plugin-i18next": "^6.1.0",
"eslint-plugin-i18next": "^6.1.1",
"eslint-plugin-path": "^1.3.0",
"knip": "^5.31.0",
"knip": "^5.50.5",
"openapi-types": "^12.1.3",
"openapi-typescript": "^7.4.1",
"prettier": "^3.3.3",
"rollup-plugin-visualizer": "^5.12.0",
"storybook": "^8.3.4",
"openapi-typescript": "^7.6.1",
"prettier": "^3.5.3",
"rollup-plugin-visualizer": "^5.14.0",
"storybook": "^8.6.12",
"tsafe": "^1.8.5",
"type-fest": "^4.26.1",
"typescript": "^5.6.2",
"vite": "^6.1.0",
"type-fest": "^4.40.0",
"typescript": "^5.8.3",
"vite": "^6.3.3",
"vite-plugin-css-injected-by-js": "^3.5.2",
"vite-plugin-dts": "^4.5.0",
"vite-plugin-dts": "^4.5.3",
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^5.1.4",
"vitest": "^3.0.6"
"vitest": "^3.1.2"
},
"engines": {
"pnpm": "8"

File diff suppressed because it is too large


@@ -118,6 +118,8 @@
"error": "Error",
"error_withCount_one": "{{count}} error",
"error_withCount_other": "{{count}} errors",
"model_withCount_one": "{{count}} model",
"model_withCount_other": "{{count}} models",
"file": "File",
"folder": "Folder",
"format": "format",
@@ -138,6 +140,8 @@
"localSystem": "Local System",
"learnMore": "Learn More",
"modelManager": "Model Manager",
"noMatches": "No matches",
"noOptions": "No options",
"nodes": "Workflows",
"notInstalled": "Not $t(common.installed)",
"openInNewTab": "Open in New Tab",
@@ -171,6 +175,8 @@
"blue": "Blue",
"alpha": "Alpha",
"selected": "Selected",
"search": "Search",
"clear": "Clear",
"tab": "Tab",
"view": "View",
"edit": "Edit",
@@ -197,7 +203,11 @@
"column": "Column",
"value": "Value",
"label": "Label",
"systemInformation": "System Information"
"systemInformation": "System Information",
"compactView": "Compact View",
"fullView": "Full View",
"options_withCount_one": "{{count}} option",
"options_withCount_other": "{{count}} options"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -258,6 +268,7 @@
"status": "Status",
"total": "Total",
"time": "Time",
"credits": "Credits",
"pending": "Pending",
"in_progress": "In Progress",
"completed": "Completed",
@@ -768,6 +779,7 @@
"description": "Description",
"edit": "Edit",
"fileSize": "File Size",
"filterModels": "Filter models",
"fluxRedux": "FLUX Redux",
"height": "Height",
"huggingFace": "HuggingFace",
@@ -787,6 +799,7 @@
"hfTokenUnableToVerify": "Unable to Verify HF Token",
"hfTokenUnableToVerifyErrorMessage": "Unable to verify HuggingFace token. This is likely due to a network error. Please try again later.",
"hfTokenSaved": "HF Token Saved",
"hfTokenReset": "HF Token Reset",
"urlUnauthorizedErrorMessage": "You may need to configure an API token to access this model.",
"urlUnauthorizedErrorMessage2": "Learn how here.",
"imageEncoderModelId": "Image Encoder Model ID",
@@ -821,10 +834,12 @@
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"name": "Name",
"noModelsInstalled": "No Models Installed",
"modelPickerFallbackNoModelsInstalled": "No models installed.",
"modelPickerFallbackNoModelsInstalled2": "Visit the <LinkComponent>Model Manager</LinkComponent> to install models.",
"noModelsInstalledDesc1": "Install models with the",
"noModelSelected": "No Model Selected",
"noMatchingModels": "No matching Models",
"noMatchingModels": "No matching models",
"noModelsInstalled": "No models installed",
"none": "none",
"path": "Path",
"pathToConfig": "Path To Config",
@@ -871,7 +886,8 @@
"installingXModels_one": "Installing {{count}} model",
"installingXModels_other": "Installing {{count}} models",
"skippingXDuplicates_one": ", skipping {{count}} duplicate",
"skippingXDuplicates_other": ", skipping {{count}} duplicates"
"skippingXDuplicates_other": ", skipping {{count}} duplicates",
"manageModels": "Manage Models"
},
"models": {
"addLora": "Add LoRA",
@@ -1093,6 +1109,7 @@
"info": "Info",
"invoke": {
"addingImagesTo": "Adding images to",
"modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your account settings to upgrade.",
"invoke": "Invoke",
"missingFieldTemplate": "Missing field template",
"missingInputForField": "missing input",
@@ -1173,7 +1190,8 @@
"width": "Width",
"gaussianBlur": "Gaussian Blur",
"boxBlur": "Box Blur",
"staged": "Staged"
"staged": "Staged",
"modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your <LinkComponent>account settings</LinkComponent> to upgrade."
},
"dynamicPrompts": {
"showDynamicPrompts": "Show Dynamic Prompts",
@@ -1312,6 +1330,8 @@
"unableToCopyDesc": "Your browser does not support clipboard access. Firefox users may be able to fix this by following ",
"unableToCopyDesc_theseSteps": "these steps",
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"imagen3IncompatibleGenerationMode": "Google Imagen3 supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models for Inpainting and Outpainting tasks.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
"workflowUnpublished": "Workflow Unpublished"


@@ -5,6 +5,7 @@ import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $customNavComponent } from 'app/store/nanostores/customNavComponent';
@@ -12,10 +13,13 @@ import type { CustomStarUi } from 'app/store/nanostores/customStarUI';
import { $customStarUI } from 'app/store/nanostores/customStarUI';
import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { $logo } from 'app/store/nanostores/logo';
import { $onClickGoToModelManager } from 'app/store/nanostores/onClickGoToModelManager';
import { $openAPISchemaUrl } from 'app/store/nanostores/openAPISchemaUrl';
import { $projectId, $projectName, $projectUrl } from 'app/store/nanostores/projectId';
import { $queueId, DEFAULT_QUEUE_ID } from 'app/store/nanostores/queueId';
import { $store } from 'app/store/nanostores/store';
import { $toastMap } from 'app/store/nanostores/toastMap';
import { $whatsNew } from 'app/store/nanostores/whatsNew';
import { createStore } from 'app/store/store';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
@@ -29,6 +33,7 @@ import {
DEFAULT_WORKFLOW_LIBRARY_TAG_CATEGORIES,
} from 'features/nodes/store/workflowLibrarySlice';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { ToastConfig } from 'features/toast/toast';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
@@ -45,6 +50,7 @@ interface Props extends PropsWithChildren {
token?: string;
config?: PartialAppConfig;
customNavComponent?: ReactNode;
accountSettingsLink?: string;
middleware?: Middleware[];
projectId?: string;
projectName?: string;
@@ -55,10 +61,16 @@ interface Props extends PropsWithChildren {
socketOptions?: Partial<ManagerOptions & SocketOptions>;
isDebugging?: boolean;
logo?: ReactNode;
toastMap?: Record<string, ToastConfig>;
whatsNew?: ReactNode[];
workflowCategories?: WorkflowCategory[];
workflowTagCategories?: WorkflowTagCategory[];
workflowSortOptions?: WorkflowSortOption[];
loggingOverrides?: LoggingOverrides;
/**
* If provided, overrides in-app navigation to the model manager
*/
onClickGoToModelManager?: () => void;
}
const InvokeAIUI = ({
@@ -67,6 +79,7 @@ const InvokeAIUI = ({
token,
config,
customNavComponent,
accountSettingsLink,
middleware,
projectId,
projectName,
@@ -77,10 +90,13 @@ const InvokeAIUI = ({
socketOptions,
isDebugging = false,
logo,
toastMap,
workflowCategories,
workflowTagCategories,
workflowSortOptions,
loggingOverrides,
onClickGoToModelManager,
whatsNew,
}: Props) => {
useLayoutEffect(() => {
/*
@@ -169,6 +185,16 @@ const InvokeAIUI = ({
};
}, [customNavComponent]);
useEffect(() => {
if (accountSettingsLink) {
$accountSettingsLink.set(accountSettingsLink);
}
return () => {
$accountSettingsLink.set(undefined);
};
}, [accountSettingsLink]);
useEffect(() => {
if (openAPISchemaUrl) {
$openAPISchemaUrl.set(openAPISchemaUrl);
@@ -205,6 +231,36 @@ const InvokeAIUI = ({
};
}, [logo]);
useEffect(() => {
if (toastMap) {
$toastMap.set(toastMap);
}
return () => {
$toastMap.set(undefined);
};
}, [toastMap]);
useEffect(() => {
if (whatsNew) {
$whatsNew.set(whatsNew);
}
return () => {
$whatsNew.set(undefined);
};
}, [whatsNew]);
useEffect(() => {
if (onClickGoToModelManager) {
$onClickGoToModelManager.set(onClickGoToModelManager);
}
return () => {
$onClickGoToModelManager.set(undefined);
};
}, [onClickGoToModelManager]);
useEffect(() => {
if (workflowCategories) {
$workflowLibraryCategoriesOptions.set(workflowCategories);

View File

@@ -1,3 +1,4 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
@@ -6,11 +7,14 @@ import { withResult, withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
@@ -48,32 +52,50 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
return await buildFLUXGraph(state, manager);
case 'cogview4':
return await buildCogView4Graph(state, manager);
case 'imagen3':
return await buildImagen3Graph(state, manager);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, manager);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status: 'error',
title: 'Failed to build graph',
status,
title,
description,
});
return;
}
const { g, noise, posCond } = buildGraphResult.value;
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'canvas', destination)
prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'canvas',
destination,
})
);
if (prepareBatchResult.isErr()) {
@@ -89,7 +111,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
await req.unwrap();
log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
} finally {
req.reset();
}

View File

@@ -18,16 +18,24 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
const state = getState();
const { prepend } = action.payload;
const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond, 'upscaling', 'gallery');
const batchConfig = prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'upscaling',
destination: 'gallery',
});
const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
try {
await req.unwrap();
log.debug(parseify({ batchConfig }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
} finally {
req.reset();
}

View File

@@ -0,0 +1,3 @@
import { atom } from 'nanostores';
export const $accountSettingsLink = atom<string | undefined>(undefined);

View File

@@ -0,0 +1,3 @@
import { atom } from 'nanostores';
export const $onClickGoToModelManager = atom<(() => void) | undefined>(undefined);

View File

@@ -0,0 +1,4 @@
import type { ToastConfig } from 'features/toast/toast';
import { atom } from 'nanostores';
export const $toastMap = atom<Record<string, ToastConfig> | undefined>(undefined);

View File

@@ -0,0 +1,4 @@
import { atom } from 'nanostores';
import type { ReactNode } from 'react';
export const $whatsNew = atom<ReactNode[] | undefined>(undefined);
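Each of these new atoms is bridged from an `InvokeAIUI` prop with the same set-on-mount / reset-on-cleanup effect shown earlier in the diff. A sketch of that pattern with a hand-rolled atom (standing in for nanostores so the example is self-contained; the toast key and message are illustrative):

```typescript
// Hypothetical minimal atom, standing in for nanostores' atom().
const makeAtom = <T>(initial: T) => {
  let value = initial;
  return { get: () => value, set: (next: T) => { value = next; } };
};

const $toastMap = makeAtom<Record<string, string> | undefined>(undefined);

// Effect body (mount): publish the prop into the store.
$toastMap.set({ usageLimit: 'You have reached your usage limit.' });
console.log($toastMap.get()?.usageLimit); // prints "You have reached your usage limit."

// Cleanup (unmount): reset so stale config does not leak between mounts.
$toastMap.set(undefined);
```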

View File

@@ -145,7 +145,10 @@ const unserialize: UnserializeFunction = (data, key) => {
);
return transformed;
} catch (err) {
log.warn({ error: serializeError(err) }, `Error rehydrating slice "${key}", falling back to default initial state`);
log.warn(
{ error: serializeError(err as Error) },
`Error rehydrating slice "${key}", falling back to default initial state`
);
return persistConfig.initialState;
}
};

View File

@@ -28,7 +28,8 @@ export type AppFeature =
| 'starterModels'
| 'hfToken'
| 'retryQueueItem'
| 'cancelAndClearAll';
| 'cancelAndClearAll'
| 'chatGPT4oModels';
/**
* A disable-able Stable Diffusion feature
*/
@@ -83,6 +84,7 @@ export type AppConfig = {
metadataFetchDebounce?: number;
workflowFetchDebounce?: number;
isLocal?: boolean;
shouldShowCredits: boolean;
sd: {
defaultModel?: string;
disabledControlNetModels: string[];

File diff suppressed because it is too large

View File

@@ -38,7 +38,7 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
}, [optionsFilter, getIsDisabled, modelConfigs, shouldShowModelDescriptions]);
const value = useMemo(
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)) ?? null,
[options, selectedModel]
);

View File

@@ -1,6 +1,10 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { memo } from 'react';
/**
* A typed version of React.memo, useful for components that take generics.
*/
export const typedMemo: <T>(c: T) => T = memo;
export const typedMemo: <T extends keyof JSX.IntrinsicElements | React.JSXElementConstructor<any>>(
component: T,
propsAreEqual?: (prevProps: React.ComponentProps<T>, nextProps: React.ComponentProps<T>) => boolean
) => T & { displayName?: string } = memo;
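The widened signature exists because `React.memo`'s default typing erases a generic component's type parameter. The underlying identity-cast technique can be sketched without React (names here are illustrative):

```typescript
// A higher-order function whose typing erases the input's type,
// standing in for React.memo in this sketch.
const erasingWrap = (fn: unknown): unknown => fn;

// The typedMemo trick: cast the wrapper so the wrapped value keeps its
// own (possibly generic) signature.
const typedWrap = erasingWrap as <T>(fn: T) => T;

const first = <T>(xs: T[]): T | undefined => xs[0];

// Still generic after wrapping: T is inferred per call site.
const wrappedFirst = typedWrap(first);
console.log(wrappedFirst([1, 2, 3])); // 1
console.log(wrappedFirst(['a', 'b'])); // "a"
```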

View File

@@ -24,6 +24,7 @@ export const CanvasAddEntityButtons = memo(() => {
const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');
return (
<Flex w="full" h="full" justifyContent="center" gap={4}>
@@ -52,6 +53,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addInpaintMask}
isDisabled={!isInpaintLayerEnabled}
>
{t('controlLayers.inpaintMask')}
</Button>

View File

@@ -25,6 +25,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');
return (
<Menu>
@@ -46,7 +47,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.regional')}>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask} isDisabled={!isInpaintLayerEnabled}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={!isRegionalGuidanceEnabled}>

View File

@@ -0,0 +1,63 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ApiModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
type Props = {
modelKey: string | null;
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => void;
};
export const GlobalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useGlobalReferenceImageModels();
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
const _onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null) => {
if (!modelConfig) {
return;
}
onChangeModel(modelConfig);
},
[onChangeModel]
);
const getIsDisabled = useCallback(
(model: AnyModelConfig): boolean => {
const hasMainModel = Boolean(currentBaseModel);
const hasSameBase = currentBaseModel === model.base;
return !hasMainModel || !hasSameBase;
},
[currentBaseModel]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChangeModel,
selectedModel,
getIsDisabled,
isLoading,
});
return (
<Tooltip label={selectedModel?.description}>
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
);
});
GlobalReferenceImageModel.displayName = 'GlobalReferenceImageModel';

View File

@@ -61,7 +61,7 @@ export const IPAdapterImagePreview = memo(
)}
{imageDTO && (
<>
<DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" />
<DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" w="full" />
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={handleResetControlImage}

View File

@@ -6,6 +6,7 @@ import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/c
import { Weight } from 'features/controlLayers/components/common/Weight';
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
import { GlobalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/GlobalReferenceImageModel';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -33,10 +34,9 @@ import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
import { IPAdapterModel } from './IPAdapterModel';
const buildSelectIPAdapter = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) =>
createSelector(
@@ -80,7 +80,7 @@ const IPAdapterSettingsContent = memo(() => {
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => {
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
},
[dispatch, entityIdentifier]
@@ -113,11 +113,7 @@ const IPAdapterSettingsContent = memo(() => {
<CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<IPAdapterModel
isRegionalGuidance={false}
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
/>
<GlobalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}

View File

@@ -4,29 +4,26 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useIPAdapterOrFLUXReduxModels } from 'services/api/hooks/modelsByType';
import { useRegionalReferenceImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
type Props = {
isRegionalGuidance: boolean;
modelKey: string | null;
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => void;
};
export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeModel }: Props) => {
const filter = (config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
if (config.base === 'flux' && config.type === 'ip_adapter') {
return false;
}
return true;
};
export const RegionalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const filter = useCallback(
(config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
if (isRegionalGuidance && config.base === 'flux' && config.type === 'ip_adapter') {
return false;
}
return true;
},
[isRegionalGuidance]
);
const [modelConfigs, { isLoading }] = useIPAdapterOrFLUXReduxModels(filter);
const [modelConfigs, { isLoading }] = useRegionalReferenceImageModels(filter);
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
const _onChangeModel = useCallback(
@@ -71,4 +68,4 @@ export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeMode
);
});
IPAdapterModel.displayName = 'IPAdapterModel';
RegionalReferenceImageModel.displayName = 'RegionalReferenceImageModel';

View File

@@ -7,7 +7,7 @@ import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLI
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
import { RegionalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/RegionalReferenceImageModel';
import { RegionalGuidanceIPAdapterSettingsEmptyState } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoRegionalGuidanceReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
@@ -140,11 +140,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
</Flex>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<IPAdapterModel
isRegionalGuidance={true}
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
/>
<RegionalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}

View File

@@ -17,16 +17,26 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
ControlLoRAConfig,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/util';
import {
initialChatGPT4oReferenceImage,
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback } from 'react';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import {
modelConfigsAdapterSelectors,
selectMainModelConfig,
selectModelConfigsQuery,
} from 'services/api/endpoints/models';
import type {
ControlLoRAModelConfig,
ControlNetModelConfig,
@@ -64,6 +74,35 @@ export const selectDefaultControlAdapter = createSelector(
}
);
export const selectDefaultRefImageConfig = createSelector(
selectMainModelConfig,
selectModelConfigsQuery,
selectBase,
(selectedMainModel, query, base): CanvasReferenceImageState['ipAdapter'] => {
if (selectedMainModel?.base === 'chatgpt-4o') {
const referenceImage = deepClone(initialChatGPT4oReferenceImage);
referenceImage.model = zModelIdentifierField.parse(selectedMainModel);
return referenceImage;
}
const { data } = query;
let model: IPAdapterModelConfig | null = null;
if (data) {
const modelConfigs = modelConfigsAdapterSelectors.selectAll(data).filter(isIPAdapterModelConfig);
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
model = compatibleModels[0] ?? modelConfigs[0] ?? null;
}
const ipAdapter = deepClone(initialIPAdapter);
if (model) {
ipAdapter.model = zModelIdentifierField.parse(model);
if (model.base === 'flux') {
ipAdapter.clipVisionModel = 'ViT-L';
}
}
return ipAdapter;
}
);
/**
* Selects the default IP adapter configuration based on the model configurations and the base.
*
@@ -146,11 +185,11 @@ export const useAddRegionalReferenceImage = () => {
export const useAddGlobalReferenceImage = () => {
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const defaultRefImage = useAppSelector(selectDefaultRefImageConfig);
const func = useCallback(() => {
const overrides = { ipAdapter: deepClone(defaultIPAdapter) };
const overrides = { ipAdapter: deepClone(defaultRefImage) };
dispatch(referenceImageAdded({ isSelected: true, overrides }));
}, [defaultIPAdapter, dispatch]);
}, [defaultRefImage, dispatch]);
return func;
};

View File

@@ -41,7 +41,7 @@ export const useCopyLayerToClipboard = () => {
});
});
} catch (error) {
log.error({ error: serializeError(error) }, 'Problem copying layer to clipboard');
log.error({ error: serializeError(error as Error) }, 'Problem copying layer to clipboard');
toast({
status: 'error',
title: t('toast.problemCopyingLayer'),
@@ -82,7 +82,7 @@ export const useCopyCanvasToClipboard = (region: 'canvas' | 'bbox') => {
toast({ title: t('controlLayers.regionCopiedToClipboard', { region: startCase(region) }) });
});
} catch (error) {
log.error({ error: serializeError(error) }, 'Failed to save canvas to gallery');
log.error({ error: serializeError(error as Error) }, 'Failed to save canvas to gallery');
toast({ title: t('controlLayers.copyRegionError', { region: startCase(region) }), status: 'error' });
}
}, [canvasManager.compositor, canvasManager.stateApi, clipboard, region, t]);

View File

@@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHook
import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { selectDefaultIPAdapter, selectDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
controlLayerAdded,
@@ -198,7 +198,7 @@ export const useNewRegionalReferenceImageFromBbox = () => {
export const useNewGlobalReferenceImageFromBbox = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const defaultIPAdapter = useAppSelector(selectDefaultRefImageConfig);
const arg = useMemo<UseSaveCanvasArg>(() => {
const onSave = (imageDTO: ImageDTO) => {

View File

@@ -1,5 +1,10 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsCogView4, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsImagen3,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
import type { Equals } from 'tsafe';
@@ -8,23 +13,25 @@ import { assert } from 'tsafe';
export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isSD3 = useAppSelector(selectIsSD3);
const isCogView4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'reference_image':
return !isSD3 && !isCogView4;
return !isSD3 && !isCogView4 && !isImagen3;
case 'regional_guidance':
return !isSD3 && !isCogView4;
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
case 'control_layer':
return !isSD3 && !isCogView4;
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
case 'inpaint_mask':
return true;
return !isImagen3 && !isChatGPT4o;
case 'raster_layer':
return true;
return !isImagen3 && !isChatGPT4o;
default:
assert<Equals<typeof entityType, never>>(false);
}
}, [entityType, isSD3, isCogView4]);
}, [entityType, isSD3, isCogView4, isImagen3, isChatGPT4o]);
return isEntityTypeEnabled;
};

View File

@@ -41,7 +41,7 @@ export const useSaveLayerToAssets = () => {
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
});
} catch (error) {
log.error({ error: serializeError(error) }, 'Problem copying layer to clipboard');
log.error({ error: serializeError(error as Error) }, 'Problem copying layer to clipboard');
toast({
status: 'error',
title: t('toast.problemSavingLayer'),

View File

@@ -519,7 +519,7 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
this.manager.cache.imageNameCache.set(hash, imageDTO.image_name);
return imageDTO;
} catch (error) {
this.log.error({ rasterizeArgs, error: serializeError(error) }, 'Failed to rasterize entity');
this.log.error({ rasterizeArgs, error: serializeError(error as Error) }, 'Failed to rasterize entity');
throw error;
} finally {
this.manager.stateApi.$rasterizingAdapter.set(null);

View File

@@ -346,7 +346,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
// If the user is not holding shift, the transform is retaining aspect ratio. It's not possible to snap to the grid
// in this case, because that would change the aspect ratio. So, we only snap to the grid when shift is held.
const gridSize = this.manager.stateApi.$shiftKey.get() ? this.manager.stateApi.getGridSize() : 1;
const gridSize = this.manager.stateApi.$shiftKey.get() ? this.manager.stateApi.getPositionGridSize() : 1;
// We need to snap the anchor to the selected grid size, but the positions provided to this callback are absolute,
// scaled coordinates. They need to be converted to stage coordinates, snapped, then converted back to absolute
@@ -464,7 +464,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
return;
}
const { rect } = this.manager.stateApi.getBbox();
const gridSize = this.manager.stateApi.getGridSize();
const gridSize = this.manager.stateApi.getPositionGridSize();
const width = this.konva.proxyRect.width();
const height = this.konva.proxyRect.height();
const scaleX = rect.width / width;
@@ -498,7 +498,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
return;
}
const { rect } = this.manager.stateApi.getBbox();
const gridSize = this.manager.stateApi.getGridSize();
const gridSize = this.manager.stateApi.getPositionGridSize();
const width = this.konva.proxyRect.width();
const height = this.konva.proxyRect.height();
const scaleX = rect.width / width;
@@ -523,7 +523,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
onDragMove = () => {
// Snap the interaction rect to the grid
const gridSize = this.manager.stateApi.getGridSize();
const gridSize = this.manager.stateApi.getPositionGridSize();
this.konva.proxyRect.x(roundToMultiple(this.konva.proxyRect.x(), gridSize));
this.konva.proxyRect.y(roundToMultiple(this.konva.proxyRect.y(), gridSize));

View File

@@ -112,7 +112,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
return;
}
const imageElementResult = await withResultAsync(() => loadImage(imageDTO.image_url));
const imageElementResult = await withResultAsync(() => loadImage(imageDTO.image_url, true));
if (imageElementResult.isErr()) {
// Image loading failed (e.g. the URL to the "physical" image is invalid)
this.onFailedToLoadImage(t('controlLayers.unableToLoadImage', 'Unable to load image'));

View File

@@ -493,7 +493,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
* Gets the _positional_ grid size for the current canvas. Note that this is not the same as bbox grid size, which is
* based on the currently-selected model.
*/
getGridSize = (): number => {
getPositionGridSize = (): number => {
const snapToGrid = this.getSettings().snapToGrid;
if (!snapToGrid) {
return 1;
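Positional snapping built on this grid size reduces to rounding to the nearest multiple, and returning 1 when snapping is off leaves integer coordinates untouched. A self-contained sketch (`roundToMultiple` mirrors the helper the transformer calls):

```typescript
// Snap a coordinate to the nearest multiple of the positional grid size.
const roundToMultiple = (value: number, multiple: number): number =>
  Math.round(value / multiple) * multiple;

console.log(roundToMultiple(37, 8)); // 40 (snapped to the 8px grid)
console.log(roundToMultiple(37, 1)); // 37 (grid size 1 disables snapping)
```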

View File

@@ -4,8 +4,10 @@ import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { fitRectToGrid, getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectBbox } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect } from 'features/controlLayers/store/types';
import type { Coordinate, Rect, Tool } from 'features/controlLayers/store/types';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import Konva from 'konva';
import { noop } from 'lodash-es';
import { atom } from 'nanostores';
@@ -178,6 +180,9 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
// Listen for the bbox overlay setting to update the overlay's visibility
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectBboxOverlay, this.render));
// Listen for the model changing - some model types constrain the bbox to a certain size or aspect ratio.
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectModel, this.render));
// Update on busy state changes
this.subscriptions.add(this.manager.$isBusy.listen(this.render));
}
@@ -218,12 +223,25 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
this.syncOverlay();
const model = this.manager.stateApi.runSelector(selectModel);
this.konva.transformer.setAttrs({
listening: tool === 'bbox',
enabledAnchors: tool === 'bbox' ? ALL_ANCHORS : NO_ANCHORS,
enabledAnchors: this.getEnabledAnchors(tool, model),
});
};
getEnabledAnchors = (tool: Tool, model?: ModelIdentifierField | null): string[] => {
if (tool !== 'bbox') {
return NO_ANCHORS;
}
if (model?.base === 'imagen3' || model?.base === 'chatgpt-4o') {
// The bbox is not resizable in these modes
return NO_ANCHORS;
}
return ALL_ANCHORS;
};
syncOverlay = () => {
const bboxOverlay = this.manager.stateApi.getSettings().bboxOverlay;
@@ -251,7 +269,7 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
onDragMove = () => {
// The grid size here is the _position_ grid size, not the _dimension_ grid size - it is not constrained by the
// currently-selected model.
const gridSize = this.manager.stateApi.getGridSize();
const gridSize = this.manager.stateApi.getPositionGridSize();
const bbox = this.manager.stateApi.getBbox();
const bboxRect: Rect = {
...bbox.rect,
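The anchor-selection rule added above can be sketched standalone (types simplified from the Konva module): the bbox is only interactive while the bbox tool is active, and the API models produce fixed-size outputs, so resize anchors are disabled for them.

```typescript
// Simplified sketch of CanvasBboxToolModule.getEnabledAnchors.
type Tool = 'bbox' | 'brush' | 'move';
const ALL_ANCHORS = ['top-left', 'top-right', 'bottom-left', 'bottom-right'];
const NO_ANCHORS: string[] = [];

const getEnabledAnchors = (tool: Tool, base?: string): string[] => {
  if (tool !== 'bbox') {
    return NO_ANCHORS;
  }
  if (base === 'imagen3' || base === 'chatgpt-4o') {
    // The bbox is not resizable for these API models.
    return NO_ANCHORS;
  }
  return ALL_ANCHORS;
};

console.log(getEnabledAnchors('bbox', 'sdxl').length); // 4
console.log(getEnabledAnchors('bbox', 'imagen3').length); // 0
```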

View File

@@ -476,15 +476,24 @@ export function getImageDataTransparency(imageData: ImageData): Transparency {
/**
* Loads an image from a URL and returns a promise that resolves with the loaded image element.
* @param src The image source URL
* @param fetchUrlFirst Whether to fetch the image's URL first, assuming the provided `src` will redirect to a different URL. This addresses an issue where CORS headers are dropped during a redirect.
* @returns A promise that resolves with the loaded image element
*/
export function loadImage(src: string): Promise<HTMLImageElement> {
export async function loadImage(src: string, fetchUrlFirst?: boolean): Promise<HTMLImageElement> {
const authToken = $authToken.get();
let url = src;
if (authToken && fetchUrlFirst) {
const response = await fetch(`${src}?url_only=true`, { credentials: 'include' });
const data = await response.json();
url = data.url;
}
return new Promise((resolve, reject) => {
const imageElement = new Image();
imageElement.onload = () => resolve(imageElement);
imageElement.onerror = (error) => reject(error);
imageElement.crossOrigin = $authToken.get() ? 'use-credentials' : 'anonymous';
imageElement.src = src;
imageElement.src = url;
});
}
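One detail worth isolating from the `loadImage` change: the new code builds the `url_only` request with a bare `${src}?url_only=true`. A query-string-aware helper is a safer shape for that step. The helper name and the `&`-vs-`?` handling are assumptions for illustration; only the `url_only=true` parameter comes from the diff above:

```typescript
// Hypothetical helper: build the pre-fetch URL used to resolve the final
// (post-redirect) image URL before assigning it to an <img> element.
function buildUrlOnlyRequest(src: string): string {
  // Append with '&' when the source URL already carries a query string.
  const separator = src.includes('?') ? '&' : '?';
  return `${src}${separator}url_only=true`;
}
```

Resolving the URL first works around redirects dropping CORS headers: the browser then loads the final URL directly, so the `crossOrigin` attribute applies to a non-redirecting response.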

View File

@@ -10,12 +10,12 @@ export type Extents = {
/**
* Get the bounding box of an image.
* @param buffer The ArrayBuffer of the image to get the bounding box of.
* @param buffer The ArrayBufferLike of the image to get the bounding box of.
* @param width The width of the image.
* @param height The height of the image.
* @returns The minimum and maximum x and y values of the image's bounding box, or null if the image has no pixels.
*/
const getImageDataBboxArrayBuffer = (buffer: ArrayBuffer, width: number, height: number): Extents | null => {
const getImageDataBboxArrayBufferLike = (buffer: ArrayBufferLike, width: number, height: number): Extents | null => {
let minX = width;
let minY = height;
let maxX = -1;
@@ -50,7 +50,7 @@ const getImageDataBboxArrayBuffer = (buffer: ArrayBuffer, width: number, height:
export type GetBboxTask = {
type: 'get_bbox';
data: { id: string; buffer: ArrayBuffer; width: number; height: number };
data: { id: string; buffer: ArrayBufferLike; width: number; height: number };
};
type TaskWithTimestamps<T extends Record<string, unknown>> = T & { started: number | null; finished: number | null };
@@ -95,7 +95,7 @@ function processNextTask() {
// Process the task
if (task.type === 'get_bbox') {
const { buffer, width, height, id } = task.data;
const extents = getImageDataBboxArrayBuffer(buffer, width, height);
const extents = getImageDataBboxArrayBufferLike(buffer, width, height);
const result: ExtentsResult = {
type: 'extents',
data: { id, extents },

View File

@@ -34,9 +34,10 @@ import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/com
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect } from 'konva/lib/types';
import { merge } from 'lodash-es';
import { isEqual, merge } from 'lodash-es';
import type { UndoableOptions } from 'redux-undo';
import type {
ApiModelConfig,
ControlLoRAModelConfig,
ControlNetModelConfig,
FLUXReduxModelConfig,
@@ -67,7 +68,7 @@ import type {
IPMethodV2,
T2IAdapterConfig,
} from './types';
import { getEntityIdentifier, isRenderableEntity } from './types';
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagen3AspectRatioID, isRenderableEntity } from './types';
import {
converters,
getControlLayerState,
@@ -76,6 +77,7 @@ import {
getReferenceImageState,
getRegionalGuidanceState,
imageDTOToImageWithDims,
initialChatGPT4oReferenceImage,
initialControlLoRA,
initialControlNet,
initialFLUXRedux,
@@ -644,7 +646,10 @@ export const canvasSlice = createSlice({
referenceImageIPAdapterModelChanged: (
state,
action: PayloadAction<
EntityIdentifierPayload<{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null }, 'reference_image'>
EntityIdentifierPayload<
{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null },
'reference_image'
>
>
) => {
const { entityIdentifier, modelConfig } = action.payload;
@@ -652,14 +657,36 @@ export const canvasSlice = createSlice({
if (!entity) {
return;
}
const oldModel = entity.ipAdapter.model;
// First set the new model
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
if (!entity.ipAdapter.model) {
return;
}
if (entity.ipAdapter.type === 'ip_adapter' && entity.ipAdapter.model.type === 'flux_redux') {
// Switching from ip_adapter to flux_redux
if (isEqual(oldModel, entity.ipAdapter.model)) {
// Nothing changed, so we don't need to do anything
return;
}
// The type of ref image depends on the model. When the user switches the model, we rebuild the ref image,
// keeping the image itself but resetting the other parameters to the new type's initial state.
if (entity.ipAdapter.model.base === 'chatgpt-4o') {
// Switching to chatgpt-4o ref image
entity.ipAdapter = {
...initialChatGPT4oReferenceImage,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.model.type === 'flux_redux') {
// Switching to flux_redux
entity.ipAdapter = {
...initialFLUXRedux,
image: entity.ipAdapter.image,
@@ -668,17 +695,13 @@ export const canvasSlice = createSlice({
return;
}
if (entity.ipAdapter.type === 'flux_redux' && entity.ipAdapter.model.type === 'ip_adapter') {
// Switching from flux_redux to ip_adapter
if (entity.ipAdapter.model.type === 'ip_adapter') {
// Switching to ip_adapter
entity.ipAdapter = {
...initialIPAdapter,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.type === 'ip_adapter') {
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (entity.ipAdapter.model?.base === 'flux') {
entity.ipAdapter.clipVisionModel = 'ViT-L';
@@ -686,6 +709,7 @@ export const canvasSlice = createSlice({
// Fall back to ViT-H (ViT-G would also work)
entity.ipAdapter.clipVisionModel = 'ViT-H';
}
return;
}
},
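The model-switch branches in this reducer follow one rule: derive the ref image config type from the new model and carry the image over. A minimal sketch of that rule, with stand-in types in place of the real canvas state (the type strings match the diff; everything else is simplified):

```typescript
type Model = { base: string; type: string };
type RefImage = { type: string; image: string | null; model: Model | null };

// Rebuild the ref image config for a newly selected model, preserving the image.
function rebuildRefImage(current: RefImage, model: Model): RefImage {
  const image = current.image; // the image is always carried over
  if (model.base === 'chatgpt-4o') {
    return { type: 'chatgpt_4o_reference_image', image, model };
  }
  if (model.type === 'flux_redux') {
    return { type: 'flux_redux', image, model };
  }
  // Default: a plain IP Adapter ref image.
  return { type: 'ip_adapter', image, model };
}
```

The early `isEqual(oldModel, entity.ipAdapter.model)` check in the reducer is what keeps this rebuild from firing (and discarding parameters) when the user re-selects the same model.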
referenceImageIPAdapterCLIPVisionModelChanged: (
@@ -1139,7 +1163,21 @@ export const canvasSlice = createSlice({
syncScaledSize(state);
},
bboxChangedFromCanvas: (state, action: PayloadAction<IRect>) => {
state.bbox.rect = action.payload;
const newBboxRect = action.payload;
const oldBboxRect = state.bbox.rect;
state.bbox.rect = newBboxRect;
if (newBboxRect.width === oldBboxRect.width && newBboxRect.height === oldBboxRect.height) {
return;
}
const oldAspectRatio = state.bbox.aspectRatio.value;
const newAspectRatio = newBboxRect.width / newBboxRect.height;
if (oldAspectRatio === newAspectRatio) {
return;
}
// TODO(psyche): Figure out a way to handle this without resetting the aspect ratio on every change.
// This action is dispatched when the user resizes or moves the bbox from the canvas. For now, when the user
@@ -1198,6 +1236,40 @@ export const canvasSlice = createSlice({
state.bbox.aspectRatio.id = id;
if (id === 'Free') {
state.bbox.aspectRatio.isLocked = false;
} else if (state.bbox.modelBase === 'imagen3' && isImagen3AspectRatioID(id)) {
// Imagen3 supports only specific output sizes, which do not exactly match the aspect ratio, so they need special handling.
if (id === '16:9') {
state.bbox.rect.width = 1408;
state.bbox.rect.height = 768;
} else if (id === '4:3') {
state.bbox.rect.width = 1280;
state.bbox.rect.height = 896;
} else if (id === '1:1') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1024;
} else if (id === '3:4') {
state.bbox.rect.width = 896;
state.bbox.rect.height = 1280;
} else if (id === '9:16') {
state.bbox.rect.width = 768;
state.bbox.rect.height = 1408;
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else if (state.bbox.modelBase === 'chatgpt-4o' && isChatGPT4oAspectRatioID(id)) {
// ChatGPT 4o (gpt-image) supports only specific output sizes, which do not exactly match the aspect ratio, so they need special handling.
if (id === '3:2') {
state.bbox.rect.width = 1536;
state.bbox.rect.height = 1024;
} else if (id === '1:1') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1024;
} else if (id === '2:3') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1536;
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
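The if/else chains above amount to two lookup tables from aspect ratio ID to fixed output dimensions. The sizes below are the ones in the diff; expressing them as tables is an illustrative refactor, not the actual implementation:

```typescript
type Dims = { width: number; height: number };

// Fixed output sizes per aspect ratio ID. Note the sizes only approximate the
// nominal ratio (e.g. 1408/768 ≈ 1.833, not exactly 16/9 ≈ 1.778).
const IMAGEN3_SIZES: Record<string, Dims> = {
  '16:9': { width: 1408, height: 768 },
  '4:3': { width: 1280, height: 896 },
  '1:1': { width: 1024, height: 1024 },
  '3:4': { width: 896, height: 1280 },
  '9:16': { width: 768, height: 1408 },
};

const CHATGPT_4O_SIZES: Record<string, Dims> = {
  '3:2': { width: 1536, height: 1024 },
  '1:1': { width: 1024, height: 1024 },
  '2:3': { width: 1024, height: 1536 },
};
```

This is also why the reducer recomputes `aspectRatio.value` from `rect.width / rect.height` after applying the size, rather than trusting the nominal ratio from the ID.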
@@ -1670,6 +1742,13 @@ export const canvasSlice = createSlice({
const base = model?.base;
if (isMainModelBase(base) && state.bbox.modelBase !== base) {
state.bbox.modelBase = base;
if (base === 'imagen3' || base === 'chatgpt-4o') {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = 1;
state.bbox.aspectRatio.id = '1:1';
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1024;
}
syncScaledSize(state);
}
});
@@ -1802,6 +1881,10 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
};
const syncScaledSize = (state: CanvasState) => {
if (state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'chatgpt-4o') {
// Imagen3 and ChatGPT 4o have fixed output sizes. Scaled bbox is not supported.
return;
}
if (state.bbox.scaleMethod === 'auto') {
// Sync both aspect ratio and size
const { width, height } = state.bbox.rect;

View File

@@ -380,6 +380,8 @@ export const selectIsSDXL = createParamsSelector((params) => params.model?.base
export const selectIsFLUX = createParamsSelector((params) => params.model?.base === 'flux');
export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3');
export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4');
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
export const selectIsChatGTP4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
export const selectModel = createParamsSelector((params) => params.model);
export const selectModelKey = createParamsSelector((params) => params.model?.key);

View File

@@ -245,6 +245,18 @@ const zFLUXReduxConfig = z.object({
});
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
const zChatGPT4oReferenceImageConfig = z.object({
type: z.literal('chatgpt_4o_reference_image'),
image: zImageWithDims.nullable(),
/**
* TODO(psyche): Technically there is no model for ChatGPT 4o reference images - it's just a field in the API call.
* But we use a model dropdown to switch between different ref image types, so a model must exist here; otherwise
* there would be no way to switch between ref image types.
*/
model: zServerValidatedModelIdentifierField.nullable(),
});
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
const zCanvasEntityBase = z.object({
id: zId,
name: zName,
@@ -254,15 +266,19 @@ const zCanvasEntityBase = z.object({
const zCanvasReferenceImageState = zCanvasEntityBase.extend({
type: z.literal('reference_image'),
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
// This should be named `referenceImage` but we need to keep it as `ipAdapter` for backwards compatibility
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig, zChatGPT4oReferenceImageConfig]),
});
export type CanvasReferenceImageState = z.infer<typeof zCanvasReferenceImageState>;
export const isIPAdapterConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is IPAdapterConfig =>
export const isIPAdapterConfig = (config: CanvasReferenceImageState['ipAdapter']): config is IPAdapterConfig =>
config.type === 'ip_adapter';
export const isFLUXReduxConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is FLUXReduxConfig =>
export const isFLUXReduxConfig = (config: CanvasReferenceImageState['ipAdapter']): config is FLUXReduxConfig =>
config.type === 'flux_redux';
export const isChatGPT4oReferenceImageConfig = (
config: CanvasReferenceImageState['ipAdapter']
): config is ChatGPT4oReferenceImageConfig => config.type === 'chatgpt_4o_reference_image';
const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
export type FillStyle = z.infer<typeof zFillStyle>;
@@ -387,9 +403,18 @@ export type StagingAreaImage = {
offsetY: number;
};
const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export const zImagen3AspectRatioID = z.enum(['16:9', '4:3', '1:1', '3:4', '9:16']);
export const isImagen3AspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
zImagen3AspectRatioID.safeParse(v).success;
export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);
export const isChatGPT4oAspectRatioID = (v: unknown): v is z.infer<typeof zChatGPT4oAspectRatioID> =>
zChatGPT4oAspectRatioID.safeParse(v).success;
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
export const isAspectRatioID = (v: string): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
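The zod pattern used here (an enum schema plus a `safeParse`-based type guard) can be sketched without the library: a readonly tuple plus a user-defined type guard gives equivalent compile-time narrowing. This is a library-free illustration of the pattern, not the project's code:

```typescript
// Equivalent of z.enum([...]) + safeParse-based guard, without zod.
const IMAGEN3_ASPECT_RATIO_IDS = ['16:9', '4:3', '1:1', '3:4', '9:16'] as const;
type Imagen3AspectRatioID = (typeof IMAGEN3_ASPECT_RATIO_IDS)[number];

// Accepting `unknown` (as the widened isAspectRatioID in the diff does) lets
// callers pass untrusted values without a cast.
const isImagen3AspectRatioID = (v: unknown): v is Imagen3AspectRatioID =>
  typeof v === 'string' && (IMAGEN3_ASPECT_RATIO_IDS as readonly string[]).includes(v);
```

Inside a `true` branch of the guard, TypeScript narrows `v` to the union `'16:9' | '4:3' | '1:1' | '3:4' | '9:16'`, which is what lets the reducer's per-ID branches type-check.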
const zCanvasState = z.object({
_version: z.literal(3),

View File

@@ -7,6 +7,7 @@ import type {
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
ChatGPT4oReferenceImageConfig,
ControlLoRAConfig,
ControlNetConfig,
FLUXReduxConfig,
@@ -77,6 +78,11 @@ export const initialFLUXRedux: FLUXReduxConfig = {
model: null,
imageInfluence: 'highest',
};
export const initialChatGPT4oReferenceImage: ChatGPT4oReferenceImageConfig = {
type: 'chatgpt_4o_reference_image',
image: null,
model: null,
};
export const initialT2IAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
model: null,

View File

@@ -4,7 +4,7 @@ import type { PersistConfig, RootState } from 'app/store/store';
import { z } from 'zod';
const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
export type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
export const isSeedBehaviour = (v: unknown): v is SeedBehaviour => zSeedBehaviour.safeParse(v).success;
export interface DynamicPromptsState {

View File

@@ -1,5 +1,5 @@
import type { FlexProps } from '@invoke-ai/ui-library';
import { Box, Flex, IconButton, Tooltip, useShiftModifier } from '@invoke-ai/ui-library';
import { Box, chakra, Flex, IconButton, Tooltip, useShiftModifier } from '@invoke-ai/ui-library';
import { getOverlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import { useClipboard } from 'common/hooks/useClipboard';
import { Formatter } from 'fracturedjsonjs';
@@ -26,6 +26,8 @@ const overlayscrollbarsOptions = getOverlayScrollbarsParams({
overflowY: 'scroll',
}).options;
const ChakraPre = chakra('pre');
const DataViewer = (props: Props) => {
const { label, data, fileName, withDownload = true, withCopy = true, extraCopyActions, ...rest } = props;
const dataString = useMemo(() => (isString(data) ? data : formatter.Serialize(data)) ?? '', [data]);
@@ -51,7 +53,7 @@ const DataViewer = (props: Props) => {
<Flex bg="base.800" borderRadius="base" flexGrow={1} w="full" h="full" position="relative" {...rest}>
<Box position="absolute" top={0} left={0} right={0} bottom={0} overflow="auto" p={2} fontSize="sm">
<OverlayScrollbarsComponent defer style={overlayScrollbarsStyles} options={overlayscrollbarsOptions}>
<pre>{dataString}</pre>
<ChakraPre whiteSpace="pre-wrap">{dataString}</ChakraPre>
</OverlayScrollbarsComponent>
</Box>
<Flex position="absolute" top={0} insetInlineEnd={0} p={2}>

View File

@@ -1,6 +1,6 @@
import type { AppDispatch, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { selectDefaultIPAdapter, selectDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasReset } from 'features/controlLayers/store/actions';
@@ -116,7 +116,7 @@ export const createNewCanvasEntityFromImage = (arg: {
break;
}
case 'reference_image': {
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
const ipAdapter = deepClone(selectDefaultRefImageConfig(getState()));
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));
break;
@@ -238,7 +238,7 @@ export const newCanvasFromImage = (arg: {
break;
}
case 'reference_image': {
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
const ipAdapter = deepClone(selectDefaultRefImageConfig(getState()));
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
dispatch(canvasReset());
dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));

View File

@@ -58,7 +58,7 @@ const LoRASelect = () => {
const noOptionsMessage = useCallback(() => t('models.noMatchingLoRAs'), [t]);
return (
<FormControl isDisabled={!options.length}>
<FormControl isDisabled={!options.length} gap={2}>
<InformationalPopover feature="lora">
<FormLabel>{t('models.concepts')} </FormLabel>
</InformationalPopover>

View File

@@ -12,63 +12,45 @@ import { skipToken } from '@reduxjs/toolkit/query';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { toast } from 'features/toast/toast';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetHFTokenStatusQuery, useSetHFTokenMutation } from 'services/api/endpoints/models';
import { UNAUTHORIZED_TOAST_ID } from 'services/events/onModelInstallError';
import {
useGetHFTokenStatusQuery,
useResetHFTokenMutation,
useSetHFTokenMutation,
} from 'services/api/endpoints/models';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
export const HFToken = () => {
const { t } = useTranslation();
const isHFTokenEnabled = useFeatureStatus('hfToken');
const [token, setToken] = useState('');
const { currentData } = useGetHFTokenStatusQuery(isHFTokenEnabled ? undefined : skipToken);
const [trigger, { isLoading, isUninitialized }] = useSetHFTokenMutation();
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setToken(e.target.value);
}, []);
const onClick = useCallback(() => {
trigger({ token })
.unwrap()
.then((res) => {
if (res === 'valid') {
setToken('');
toast({
id: UNAUTHORIZED_TOAST_ID,
title: t('modelManager.hfTokenSaved'),
status: 'success',
duration: 3000,
});
}
});
}, [t, token, trigger]);
const error = useMemo(() => {
if (!currentData || isUninitialized || isLoading) {
return null;
switch (currentData) {
case 'invalid':
return t('modelManager.hfTokenInvalidErrorMessage');
case 'unknown':
return t('modelManager.hfTokenUnableToVerifyErrorMessage');
case 'valid':
case undefined:
return null;
default:
assert<Equals<never, typeof currentData>>(false, 'Unexpected HF token status');
}
if (currentData === 'invalid') {
return t('modelManager.hfTokenInvalidErrorMessage');
}
if (currentData === 'unknown') {
return t('modelManager.hfTokenUnableToVerifyErrorMessage');
}
return null;
}, [currentData, isLoading, isUninitialized, t]);
}, [currentData, t]);
if (!currentData || currentData === 'valid') {
if (!currentData) {
return null;
}
return (
<Flex borderRadius="base" w="full">
<FormControl isInvalid={!isUninitialized && Boolean(error)} orientation="vertical">
<FormControl isInvalid={Boolean(error)} orientation="vertical">
<FormLabel>{t('modelManager.hfTokenLabel')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
<Input type="password" value={token} onChange={onChange} />
<Button onClick={onClick} size="sm" isDisabled={token.trim().length === 0} isLoading={isLoading}>
{t('common.save')}
</Button>
</Flex>
{error && <SetHFTokenInput />}
{!error && <ResetHFTokenButton />}
<FormHelperText>
<ExternalLink label={t('modelManager.hfTokenHelperText')} href="https://huggingface.co/settings/tokens" />
</FormHelperText>
@@ -77,3 +59,73 @@ export const HFToken = () => {
</Flex>
);
};
const PLACEHOLDER_TOKEN = Array.from({ length: 37 }, () => 'a').join('');
const ResetHFTokenButton = memo(() => {
const { t } = useTranslation();
const [resetHFToken, { isLoading }] = useResetHFTokenMutation();
const onClick = useCallback(() => {
resetHFToken()
.unwrap()
.then(() => {
toast({
title: t('modelManager.hfTokenReset'),
status: 'info',
});
});
}, [resetHFToken, t]);
return (
<Flex gap={3} alignItems="center" w="full">
<Input type="password" value={PLACEHOLDER_TOKEN} isDisabled />
<Button onClick={onClick} size="sm" isLoading={isLoading}>
{t('common.reset')}
</Button>
</Flex>
);
});
ResetHFTokenButton.displayName = 'ResetHFTokenButton';
const SetHFTokenInput = memo(() => {
const { t } = useTranslation();
const [token, setToken] = useState('');
const [trigger, { isLoading }] = useSetHFTokenMutation();
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setToken(e.target.value);
}, []);
const onClick = useCallback(() => {
trigger({ token })
.unwrap()
.then((res) => {
switch (res) {
case 'valid':
setToken('');
toast({
title: t('modelManager.hfTokenSaved'),
status: 'success',
});
break;
case 'invalid':
case 'unknown':
default:
toast({
title: t('modelManager.hfTokenUnableToVerify'),
status: 'error',
});
break;
}
});
}, [t, token, trigger]);
return (
<Flex gap={3} alignItems="center" w="full">
<Input type="password" value={token} onChange={onChange} />
<Button onClick={onClick} size="sm" isDisabled={token.trim().length === 0} isLoading={isLoading}>
{t('common.save')}
</Button>
</Flex>
);
});
SetHFTokenInput.displayName = 'SetHFTokenInput';

View File

@@ -1,11 +1,10 @@
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetHFTokenStatusQuery, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { HFToken } from './HFToken';
import { HuggingFaceResults } from './HuggingFaceResults';
@@ -16,7 +15,6 @@ export const HuggingFaceForm = memo(() => {
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();
const isHFTokenEnabled = useFeatureStatus('hfToken');
const { currentData } = useGetHFTokenStatusQuery(isHFTokenEnabled ? undefined : skipToken);
const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
const [installModel] = useInstallModel();
@@ -68,7 +66,7 @@ export const HuggingFaceForm = memo(() => {
<FormHelperText>{t('modelManager.huggingFaceHelper')}</FormHelperText>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</FormControl>
{currentData !== 'valid' && <HFToken />}
{isHFTokenEnabled && <HFToken />}
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
</Flex>
);

View File

@@ -7,7 +7,7 @@ type Props = {
base: BaseModelType;
};
const BASE_COLOR_MAP: Record<BaseModelType, string> = {
export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
any: 'base',
'sd-1': 'green',
'sd-2': 'teal',
@@ -15,12 +15,14 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
flux: 'gold',
cogview4: 'orange',
cogview4: 'red',
imagen3: 'pink',
'chatgpt-4o': 'pink',
};
const ModelBaseBadge = ({ base }: Props) => {
return (
<Badge flexGrow={0} colorScheme={BASE_COLOR_MAP[base]} variant="subtle" h="min-content">
<Badge flexGrow={0} flexShrink={0} colorScheme={BASE_COLOR_MAP[base]} variant="subtle" h="min-content">
{MODEL_TYPE_SHORT_MAP[base]}
</Badge>
);

View File

@@ -17,6 +17,7 @@ const FORMAT_NAME_MAP: Record<AnyModelConfig['format'], string> = {
bnb_quantized_int8b: 'bnb_quantized_int8b',
bnb_quantized_nf4b: 'quantized',
gguf_quantized: 'gguf',
api: 'api',
};
const FORMAT_COLOR_MAP: Record<AnyModelConfig['format'], string> = {
@@ -30,6 +31,7 @@ const FORMAT_COLOR_MAP: Record<AnyModelConfig['format'], string> = {
bnb_quantized_int8b: 'base',
bnb_quantized_nf4b: 'base',
gguf_quantized: 'base',
api: 'base',
};
const ModelFormatBadge = ({ format }: Props) => {

View File

@@ -15,7 +15,6 @@ const ModelImage = ({ image_url }: Props) => {
<Flex
height={MODEL_IMAGE_THUMBNAIL_SIZE}
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
bg="base.650"
borderRadius="base"
alignItems="center"
justifyContent="center"

View File

@@ -1,10 +1,12 @@
import { FloatFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInput';
import { FloatFieldInputAndSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInputAndSlider';
import { FloatFieldSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldSlider';
import ChatGPT4oModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ChatGPT4oModelFieldInputComponent';
import { FloatFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatFieldCollectionInputComponent';
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
import { IntegerFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerFieldCollectionInputComponent';
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
@@ -23,6 +25,8 @@ import {
isBoardFieldInputTemplate,
isBooleanFieldInputInstance,
isBooleanFieldInputTemplate,
isChatGPT4oModelFieldInputInstance,
isChatGPT4oModelFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isCLIPGEmbedModelFieldInputInstance,
@@ -57,6 +61,8 @@ import {
isImageFieldInputTemplate,
isImageGeneratorFieldInputInstance,
isImageGeneratorFieldInputTemplate,
isImagen3ModelFieldInputInstance,
isImagen3ModelFieldInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerFieldInputInstance,
@@ -394,6 +400,20 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <FluxReduxModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isImagen3ModelFieldInputTemplate(template)) {
if (!isImagen3ModelFieldInputInstance(field)) {
return null;
}
return <Imagen3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isChatGPT4oModelFieldInputTemplate(template)) {
if (!isChatGPT4oModelFieldInputInstance(field)) {
return null;
}
return <ChatGPT4oModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isColorFieldInputTemplate(template)) {
if (!isColorFieldInputInstance(field)) {
return null;

View File

@@ -1,11 +1,8 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPEmbedModelConfig } from 'services/api/types';
@@ -15,11 +12,9 @@ type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedMode
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: CLIPEmbedModelConfig | null) => {
if (!value) {
return;
@@ -34,32 +29,15 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,11 +1,8 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPGEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPGEmbedModelConfig, isCLIPGEmbedModelConfig } from 'services/api/types';
@@ -15,12 +12,10 @@ type Props = FieldComponentProps<CLIPGEmbedModelFieldInputInstance, CLIPGEmbedMo
const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: CLIPGEmbedModelConfig | null) => {
if (!value) {
return;
@@ -35,32 +30,15 @@ const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config))}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
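
Several of the diffs above pre-filter the config list with a type-guard predicate (e.g. `modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config))`) before handing it to the shared combobox. A self-contained sketch of that pattern, using hypothetical simplified config shapes rather than the app's real types in `services/api/types`:

```typescript
// Hypothetical, simplified model-config shape for illustration only.
type ModelConfig = { name: string; type: 'clip_embed' | 'vae' | 'lora' };

// A user-defined type guard: filter() with a guard returns the narrowed
// element type, so downstream code needs no casts.
const isCLIPEmbedConfig = (c: ModelConfig): c is ModelConfig & { type: 'clip_embed' } =>
  c.type === 'clip_embed';

const configs: ModelConfig[] = [
  { name: 'clip-g', type: 'clip_embed' },
  { name: 'sdxl-vae', type: 'vae' },
  { name: 'style-lora', type: 'lora' },
];

const clipConfigs = configs.filter(isCLIPEmbedConfig);
console.log(clipConfigs.map((c) => c.name)); // -> ['clip-g']
```

Note the CLIP-G component filters at the call site and passes the narrowed list down, so the generic combobox never needs to know about model subtypes.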

@@ -1,11 +1,8 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPLEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPLEmbedModelConfig, isCLIPLEmbedModelConfig } from 'services/api/types';
@@ -15,12 +12,10 @@ type Props = FieldComponentProps<CLIPLEmbedModelFieldInputInstance, CLIPLEmbedMo
const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: CLIPLEmbedModelConfig | null) => {
if (!value) {
return;
@@ -35,32 +30,15 @@ const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config))}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldChatGPT4oModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useChatGPT4oModels } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const ChatGPT4oModelFieldInputComponent = (
props: FieldComponentProps<ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useChatGPT4oModels();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldChatGPT4oModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(ChatGPT4oModelFieldInputComponent);

@@ -1,16 +1,13 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldControlLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type {
ControlLoRAModelFieldInputInstance,
ControlLoRAModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlLoRAModel } from 'services/api/hooks/modelsByType';
import { type ControlLoRAModelConfig, isControlLoRAModelConfig } from 'services/api/types';
import type { ControlLoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -18,12 +15,10 @@ type Props = FieldComponentProps<ControlLoRAModelFieldInputInstance, ControlLoRA
const ControlLoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useControlLoRAModel();
const _onChange = useCallback(
const onChange = useCallback(
(value: ControlLoRAModelConfig | null) => {
if (!value) {
return;
@@ -38,32 +33,15 @@ const ControlLoRAModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isControlLoRAModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

@@ -1,8 +1,6 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useControlNetModels } from 'services/api/hooks/modelsByType';
@@ -17,7 +15,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useControlNetModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: ControlNetModelConfig | null) => {
if (!value) {
return;
@@ -33,25 +31,14 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

@@ -1,8 +1,6 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxModels } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const FluxMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
@@ -31,25 +29,15 @@ const FluxMainModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

@@ -1,8 +1,6 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldFluxReduxModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxReduxModels } from 'services/api/hooks/modelsByType';
@@ -18,7 +16,7 @@ const FluxReduxModelFieldInputComponent = (
const [modelConfigs, { isLoading }] = useFluxReduxModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: FLUXReduxModelConfig | null) => {
if (!value) {
return;
@@ -34,19 +32,14 @@ const FluxReduxModelFieldInputComponent = (
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

@@ -1,11 +1,8 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
@@ -15,11 +12,9 @@ type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFie
const FluxVAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: VAEModelConfig | null) => {
if (!value) {
return;
@@ -34,27 +29,15 @@ const FluxVAEModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

@@ -1,8 +1,6 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
@@ -17,7 +15,7 @@ const IPAdapterModelFieldInputComponent = (
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useIPAdapterModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: IPAdapterModelConfig | null) => {
if (!value) {
return;
@@ -33,19 +31,14 @@ const IPAdapterModelFieldInputComponent = (
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldImagen3ModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useImagen3Models } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const Imagen3ModelFieldInputComponent = (
props: FieldComponentProps<Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useImagen3Models();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldImagen3ModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(Imagen3ModelFieldInputComponent);

@@ -1,8 +1,6 @@
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldLLaVAModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useLLaVAModels } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const LLaVAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useLLaVAModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: LlavaOnevisionConfig | null) => {
if (!value) {
return;
@@ -32,23 +30,14 @@ const LLaVAModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value} isDisabled={!options.length}>
<Combobox
value={value}
placeholder={placeholder}
noOptionsMessage={noOptionsMessage}
options={options}
onChange={onChange}
/>
</FormControl>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

@@ -1,8 +1,6 @@
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useLoRAModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: LoRAModelConfig | null) => {
if (!value) {
return;
@@ -32,23 +30,14 @@ const LoRAModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value} isDisabled={!options.length}>
<Combobox
value={value}
placeholder={placeholder}
noOptionsMessage={noOptionsMessage}
options={options}
onChange={onChange}
/>
</FormControl>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

@@ -1,8 +1,6 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
@@ -31,25 +29,15 @@ const MainModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

@@ -0,0 +1,51 @@
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { typedMemo } from 'common/util/typedMemo';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { AnyModelConfig } from 'services/api/types';
type Props<T extends AnyModelConfig> = {
value: ModelIdentifierField | undefined;
modelConfigs: T[];
isLoadingConfigs: boolean;
onChange: (value: T | null) => void;
required: boolean;
groupByType?: boolean;
};
const _ModelFieldCombobox = <T extends AnyModelConfig>({
value: _value,
modelConfigs,
isLoadingConfigs,
onChange: _onChange,
required,
groupByType,
}: Props<T>) => {
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading: isLoadingConfigs,
selectedModel: _value,
groupByType,
});
return (
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
gap={2}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
);
};
export const ModelFieldCombobox = typedMemo(_ModelFieldCombobox);
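
The new shared component centralizes the placeholder, disabled, and validity rules that the per-field components previously duplicated. The rules can be sketched standalone (the state shape and function name here are illustrative, not the hook's actual API):

```typescript
// Minimal standalone sketch of ModelFieldCombobox's display rules.
type ComboboxState = {
  placeholder: string;
  isDisabled: boolean;
  isInvalid: boolean;
};

function comboboxState(
  basePlaceholder: string,
  optionCount: number,
  hasValue: boolean,
  required: boolean
): ComboboxState {
  return {
    // Optional fields advertise themselves in the placeholder text.
    placeholder: required ? basePlaceholder : `(Optional) ${basePlaceholder}`,
    // No selectable models -> the control is disabled.
    isDisabled: optionCount === 0,
    // Only a required field with no selection is marked invalid.
    isInvalid: !hasValue && required,
  };
}

const s = comboboxState('Select a Model', 3, false, false);
console.log(s.placeholder); // prints "(Optional) Select a Model"
```

Folding `required` into one place means optional model fields no longer show as invalid when empty, which several of the replaced components got wrong (they used `isInvalid={!value}` unconditionally).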

@@ -1,9 +1,7 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useMemo } from 'react';
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
@@ -17,7 +15,7 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetModelConfigsQuery();
const _onChange = useCallback(
const onChange = useCallback(
(value: AnyModelConfig | null) => {
if (!value) {
return;
@@ -41,26 +39,15 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
return modelConfigsAdapterSelectors.selectAll(data);
}, [data]);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
groupByType: true,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
groupByType
/>
);
};
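
`ModelIdentifierFieldInputComponent` is the one call site that sets `groupByType`, since it mixes every model kind in a single picker. The grouping it requests can be sketched standalone; the shapes and function below are illustrative assumptions, not `useGroupedModelCombobox`'s real implementation:

```typescript
// Hypothetical minimal shapes for the grouped-options sketch.
type Config = { name: string; type: string };
type OptionGroup = { label: string; options: { label: string; value: string }[] };

// Bucket configs by their type string, preserving first-seen group order.
function groupConfigsByType(configs: Config[]): OptionGroup[] {
  const groups = new Map<string, OptionGroup>();
  for (const c of configs) {
    let g = groups.get(c.type);
    if (!g) {
      g = { label: c.type, options: [] };
      groups.set(c.type, g);
    }
    g.options.push({ label: c.name, value: c.name });
  }
  return [...groups.values()];
}

const grouped = groupConfigsByType([
  { name: 'sdxl-base', type: 'main' },
  { name: 'clip-g', type: 'clip_embed' },
  { name: 'sdxl-refiner', type: 'main' },
]);
console.log(grouped.map((g) => `${g.label}:${g.options.length}`)); // -> ['main:2', 'clip_embed:1']
```

Making `groupByType` an optional prop keeps the single-type pickers flat while letting this catch-all field opt in.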

@@ -1,8 +1,6 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type {
SDXLRefinerModelFieldInputInstance,
SDXLRefinerModelFieldInputTemplate,
@@ -19,7 +17,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useRefinerModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
@@ -34,25 +32,15 @@ const RefinerModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

@@ -1,8 +1,6 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSD3Models } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const SD3MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSD3Models();
const _onChange = useCallback(
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
@@ -31,29 +29,15 @@ const SD3MainModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && props.fieldTemplate.required}
>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

@@ -1,8 +1,6 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSDXLModels } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSDXLModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
@@ -31,25 +29,15 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldSigLipModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSigLipModels } from 'services/api/hooks/modelsByType';
@@ -18,7 +16,7 @@ const SigLipModelFieldInputComponent = (
const [modelConfigs, { isLoading }] = useSigLipModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: SigLipModelConfig | null) => {
if (!value) {
return;
@@ -34,19 +32,14 @@ const SigLipModelFieldInputComponent = (
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldSpandrelImageToImageModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type {
SpandrelImageToImageModelFieldInputInstance,
SpandrelImageToImageModelFieldInputTemplate,
@@ -21,7 +19,7 @@ const SpandrelImageToImageModelFieldInputComponent = (
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: SpandrelImageToImageModelConfig | null) => {
if (!value) {
return;
@@ -37,19 +35,14 @@ const SpandrelImageToImageModelFieldInputComponent = (
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
@@ -18,7 +16,7 @@ const T2IAdapterModelFieldInputComponent = (
const [modelConfigs, { isLoading }] = useT2IAdapterModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: T2IAdapterModelConfig | null) => {
if (!value) {
return;
@@ -34,19 +32,14 @@ const T2IAdapterModelFieldInputComponent = (
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,12 +1,8 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldT5EncoderValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate } from 'features/nodes/types/field';
import { selectIsModelsTabDisabled } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
@@ -16,11 +12,9 @@ type Props = FieldComponentProps<T5EncoderModelFieldInputInstance, T5EncoderMode
const T5EncoderModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const isModelsTabDisabled = useAppSelector(selectIsModelsTabDisabled);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
if (!value) {
return;
@@ -35,31 +29,14 @@ const T5EncoderModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!isModelsTabDisabled && t('modelManager.starterModelsInModelManager')}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useVAEModels } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useVAEModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: VAEModelConfig | null) => {
if (!value) {
return;
@@ -31,30 +29,15 @@ const VAEModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -121,6 +121,10 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
'metadata_to_controlnets',
'metadata_to_ip_adapters',
'metadata_to_t2i_adapters',
'google_imagen3_generate',
'google_imagen3_edit',
'chatgpt_create_image',
'chatgpt_edit_image',
];
export const selectHasUnpublishableNodes = createSelector(selectNodes, (nodes) => {

View File

@@ -23,6 +23,7 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
BoardFieldValue,
BooleanFieldValue,
ChatGPT4oModelFieldValue,
CLIPEmbedModelFieldValue,
CLIPGEmbedModelFieldValue,
CLIPLEmbedModelFieldValue,
@@ -38,6 +39,7 @@ import type {
ImageFieldCollectionValue,
ImageFieldValue,
ImageGeneratorFieldValue,
Imagen3ModelFieldValue,
IntegerFieldCollectionValue,
IntegerFieldValue,
IntegerGeneratorFieldValue,
@@ -61,6 +63,7 @@ import type {
import {
zBoardFieldValue,
zBooleanFieldValue,
zChatGPT4oModelFieldValue,
zCLIPEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
@@ -76,6 +79,7 @@ import {
zImageFieldCollectionValue,
zImageFieldValue,
zImageGeneratorFieldValue,
zImagen3ModelFieldValue,
zIntegerFieldCollectionValue,
zIntegerFieldValue,
zIntegerGeneratorFieldValue,
@@ -512,6 +516,12 @@ export const nodesSlice = createSlice({
fieldFluxReduxModelValueChanged: (state, action: FieldValueAction<FluxReduxModelFieldValue>) => {
fieldValueReducer(state, action, zFluxReduxModelFieldValue);
},
fieldImagen3ModelValueChanged: (state, action: FieldValueAction<Imagen3ModelFieldValue>) => {
fieldValueReducer(state, action, zImagen3ModelFieldValue);
},
fieldChatGPT4oModelValueChanged: (state, action: FieldValueAction<ChatGPT4oModelFieldValue>) => {
fieldValueReducer(state, action, zChatGPT4oModelFieldValue);
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue);
},
@@ -679,6 +689,8 @@ export const {
fieldFluxVAEModelValueChanged,
fieldSigLipModelValueChanged,
fieldFluxReduxModelValueChanged,
fieldImagen3ModelValueChanged,
fieldChatGPT4oModelValueChanged,
fieldFloatGeneratorValueChanged,
fieldIntegerGeneratorValueChanged,
fieldStringGeneratorValueChanged,

View File

@@ -1,4 +1,5 @@
import type {
BaseModelType,
BoardField,
Classification,
ColorField,
@@ -9,9 +10,10 @@ import type {
ModelIdentifierField,
ProgressImage,
SchedulerField,
SubModelType,
T2IAdapterField,
} from 'features/nodes/types/common';
import type { Invocation, S } from 'services/api/types';
import type { Invocation, ModelType, S } from 'services/api/types';
import type { Equals, Extends } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
@@ -34,6 +36,9 @@ describe('Common types', () => {
// Model component types
test('ModelIdentifier', () => assert<Equals<ModelIdentifierField, S['ModelIdentifierField']>>());
test('ModelIdentifier', () => assert<Equals<BaseModelType, S['BaseModelType']>>());
test('ModelIdentifier', () => assert<Equals<SubModelType, S['SubModelType']>>());
test('ModelIdentifier', () => assert<Equals<ModelType, S['ModelType']>>());
// Misc types
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());

View File

@@ -8,6 +8,11 @@ export const zImageField = z.object({
image_name: z.string().trim().min(1),
});
export type ImageField = z.infer<typeof zImageField>;
export const isImageField = (field: unknown): field is ImageField => zImageField.safeParse(field).success;
const zImageFieldCollection = z.array(zImageField);
type ImageFieldCollection = z.infer<typeof zImageFieldCollection>;
export const isImageFieldCollection = (field: unknown): field is ImageFieldCollection =>
zImageFieldCollection.safeParse(field).success;
export const zBoardField = z.object({
board_id: z.string().trim().min(1),
@@ -61,8 +66,20 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner', 'flux', 'cogview4']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4']);
const zBaseModel = z.enum([
'any',
'sd-1',
'sd-2',
'sd-3',
'sdxl',
'sdxl-refiner',
'flux',
'cogview4',
'imagen3',
'chatgpt-4o',
]);
export type BaseModelType = z.infer<typeof zBaseModel>;
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o']);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
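For reference, the `isMainModelBase` guard above relies on zod's `safeParse`; the same check can be sketched dependency-free (the array literal below just restates the enum from this diff, and the helper name is reused for illustration):

```typescript
// Dependency-free sketch of the isMainModelBase type guard added above.
const MAIN_MODEL_BASES = ['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o'] as const;
type MainModelBase = (typeof MAIN_MODEL_BASES)[number];

// Narrows an unknown value to the MainModelBase union when it matches the enum.
const isMainModelBase = (base: unknown): base is MainModelBase =>
  typeof base === 'string' && (MAIN_MODEL_BASES as readonly string[]).includes(base);
```

Note that `'sdxl-refiner'` and `'any'` remain in `zBaseModel` but are deliberately absent from the main-model enum, so the guard rejects them.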
const zModelType = z.enum([
@@ -98,6 +115,7 @@ const zSubModelType = z.enum([
'scheduler',
'safety_checker',
]);
export type SubModelType = z.infer<typeof zSubModelType>;
export const zModelIdentifierField = z.object({
key: z.string().min(1),
hash: z.string().min(1),

View File

@@ -248,6 +248,14 @@ const zFluxReduxModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxReduxModelField'),
originalType: zStatelessFieldType.optional(),
});
const zImagen3ModelFieldType = zFieldTypeBase.extend({
name: z.literal('Imagen3ModelField'),
originalType: zStatelessFieldType.optional(),
});
const zChatGPT4oModelFieldType = zFieldTypeBase.extend({
name: z.literal('ChatGPT4oModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
@@ -298,6 +306,8 @@ const zStatefulFieldType = z.union([
zFluxVAEModelFieldType,
zSigLipModelFieldType,
zFluxReduxModelFieldType,
zImagen3ModelFieldType,
zChatGPT4oModelFieldType,
zColorFieldType,
zSchedulerFieldType,
zFloatGeneratorFieldType,
@@ -336,6 +346,8 @@ const modelFieldTypeNames = [
zFluxVAEModelFieldType.shape.name.value,
zSigLipModelFieldType.shape.name.value,
zFluxReduxModelFieldType.shape.name.value,
zImagen3ModelFieldType.shape.name.value,
zChatGPT4oModelFieldType.shape.name.value,
// Stateless model fields
'UNetField',
'VAEField',
@@ -1177,6 +1189,42 @@ export const isFluxReduxModelFieldInputTemplate =
buildTemplateTypeGuard<FluxReduxModelFieldInputTemplate>('FluxReduxModelField');
// #endregion
// #region Imagen3ModelField
export const zImagen3ModelFieldValue = zModelIdentifierField.optional();
const zImagen3ModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zImagen3ModelFieldValue,
});
const zImagen3ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zImagen3ModelFieldType,
originalType: zFieldType.optional(),
default: zImagen3ModelFieldValue,
});
export type Imagen3ModelFieldValue = z.infer<typeof zImagen3ModelFieldValue>;
export type Imagen3ModelFieldInputInstance = z.infer<typeof zImagen3ModelFieldInputInstance>;
export type Imagen3ModelFieldInputTemplate = z.infer<typeof zImagen3ModelFieldInputTemplate>;
export const isImagen3ModelFieldInputInstance = buildInstanceTypeGuard(zImagen3ModelFieldInputInstance);
export const isImagen3ModelFieldInputTemplate =
buildTemplateTypeGuard<Imagen3ModelFieldInputTemplate>('Imagen3ModelField');
// #endregion
// #region ChatGPT4oModelField
export const zChatGPT4oModelFieldValue = zModelIdentifierField.optional();
const zChatGPT4oModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zChatGPT4oModelFieldValue,
});
const zChatGPT4oModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zChatGPT4oModelFieldType,
originalType: zFieldType.optional(),
default: zChatGPT4oModelFieldValue,
});
export type ChatGPT4oModelFieldValue = z.infer<typeof zChatGPT4oModelFieldValue>;
export type ChatGPT4oModelFieldInputInstance = z.infer<typeof zChatGPT4oModelFieldInputInstance>;
export type ChatGPT4oModelFieldInputTemplate = z.infer<typeof zChatGPT4oModelFieldInputTemplate>;
export const isChatGPT4oModelFieldInputInstance = buildInstanceTypeGuard(zChatGPT4oModelFieldInputInstance);
export const isChatGPT4oModelFieldInputTemplate =
buildTemplateTypeGuard<ChatGPT4oModelFieldInputTemplate>('ChatGPT4oModelField');
// #endregion
// #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional();
const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -1808,6 +1856,8 @@ export const zStatefulFieldValue = z.union([
zControlLoRAModelFieldValue,
zSigLipModelFieldValue,
zFluxReduxModelFieldValue,
zImagen3ModelFieldValue,
zChatGPT4oModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
zFloatGeneratorFieldValue,
@@ -1898,6 +1948,8 @@ const zStatefulFieldInputTemplate = z.union([
zControlLoRAModelFieldInputTemplate,
zSigLipModelFieldInputTemplate,
zFluxReduxModelFieldInputTemplate,
zImagen3ModelFieldInputTemplate,
zChatGPT4oModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,

View File

@@ -1,38 +1,63 @@
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from 'app/constants';
import type { RootState } from 'app/store/store';
import { generateSeeds } from 'common/util/generateSeeds';
import randomInt from 'common/util/randomInt';
import type { SeedBehaviour } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { range } from 'lodash-es';
import type { components } from 'services/api/schema';
import type { Batch, EnqueueBatchArg, Invocation } from 'services/api/types';
import type { Batch, EnqueueBatchArg } from 'services/api/types';
import { assert } from 'tsafe';
import type { ConditioningNodes, NoiseNodes } from './types';
const getExtendedPrompts = (arg: {
seedBehaviour: SeedBehaviour;
iterations: number;
prompts: string[];
model: ModelIdentifierField;
}): string[] => {
const { seedBehaviour, iterations, prompts, model } = arg;
// Normally, the seed behaviour implicitly determines the batch size. But when we use models without seeds (like
// ChatGPT 4o) in conjunction with the per-prompt seed behaviour, we lose out on that implicit batch size. To rectify
// this, we need to create a batch of the right size by repeating the prompts.
if (seedBehaviour === 'PER_PROMPT' || model.base === 'chatgpt-4o') {
return range(iterations).flatMap(() => prompts);
}
return prompts;
};
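The repetition logic in `getExtendedPrompts` can be sketched stand-alone (a simplified version: lodash's `range` is replaced with `Array.from`, and only `model.base` is passed in rather than the full identifier):

```typescript
type SeedBehaviour = 'PER_PROMPT' | 'PER_ITERATION';

// Simplified sketch of getExtendedPrompts: seedless models (chatgpt-4o) lose
// the implicit batch size that per-prompt seeds provide, so the prompt list
// is repeated once per iteration to restore it.
const getExtendedPrompts = (
  seedBehaviour: SeedBehaviour,
  iterations: number,
  prompts: string[],
  base: string
): string[] => {
  if (seedBehaviour === 'PER_PROMPT' || base === 'chatgpt-4o') {
    return Array.from({ length: iterations }).flatMap(() => prompts);
  }
  return prompts;
};
```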
export const prepareLinearUIBatch = (
state: RootState,
g: Graph,
prepend: boolean,
noise: Invocation<NoiseNodes>,
posCond: Invocation<ConditioningNodes>,
origin: 'canvas' | 'workflows' | 'upscaling',
destination: 'canvas' | 'gallery'
): EnqueueBatchArg => {
export const prepareLinearUIBatch = (arg: {
state: RootState;
g: Graph;
prepend: boolean;
seedFieldIdentifier?: FieldIdentifier;
positivePromptFieldIdentifier: FieldIdentifier;
origin: 'canvas' | 'workflows' | 'upscaling';
destination: 'canvas' | 'gallery';
}): EnqueueBatchArg => {
const { state, g, prepend, seedFieldIdentifier, positivePromptFieldIdentifier, origin, destination } = arg;
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.params;
const { prompts, seedBehaviour } = state.dynamicPrompts;
assert(model, 'No model found in state when preparing batch');
const data: Batch['data'] = [];
const firstBatchDatumList: components['schemas']['BatchDatum'][] = [];
const secondBatchDatumList: components['schemas']['BatchDatum'][] = [];
// add seeds first to ensure the output order groups the prompts
if (seedBehaviour === 'PER_PROMPT') {
if (seedFieldIdentifier && seedBehaviour === 'PER_PROMPT') {
const seeds = generateSeeds({
count: prompts.length * iterations,
start: shouldRandomizeSeed ? undefined : seed,
// Imagen3's support for seeded generation is iffy, so we are just not going to use it in linear UI generations.
start:
model.base === 'imagen3' ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) : shouldRandomizeSeed ? undefined : seed,
});
firstBatchDatumList.push({
node_path: noise.id,
field_name: 'seed',
node_path: seedFieldIdentifier.nodeId,
field_name: seedFieldIdentifier.fieldName,
items: seeds,
});
@@ -43,16 +68,18 @@ export const prepareLinearUIBatch = (
field_name: 'seed',
items: seeds,
});
} else {
} else if (seedFieldIdentifier && seedBehaviour === 'PER_ITERATION') {
// seedBehaviour = SeedBehaviour.PerRun
const seeds = generateSeeds({
count: iterations,
start: shouldRandomizeSeed ? undefined : seed,
// Imagen3's support for seeded generation is iffy, so we are just not going to use it in linear UI generations.
start:
model.base === 'imagen3' ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) : shouldRandomizeSeed ? undefined : seed,
});
secondBatchDatumList.push({
node_path: noise.id,
field_name: 'seed',
node_path: seedFieldIdentifier.nodeId,
field_name: seedFieldIdentifier.fieldName,
items: seeds,
});
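The seed-start branching repeated in both hunks above can be factored into a small helper (a sketch only; `randomInt` and the NUMPY bounds are stand-ins for the imported utilities, and the helper name is hypothetical):

```typescript
// Assumed stand-ins for app/constants and common/util/randomInt.
const NUMPY_RAND_MIN = 0;
const NUMPY_RAND_MAX = 2 ** 32 - 1;
const randomInt = (min: number, max: number): number =>
  Math.floor(Math.random() * (max - min)) + min;

// Picks the `start` argument for generateSeeds: Imagen3 always gets a random
// start (its seeded generation is unreliable); otherwise `undefined` requests
// randomization and a fixed seed is passed through as-is.
const getSeedStart = (base: string, shouldRandomizeSeed: boolean, seed: number): number | undefined => {
  if (base === 'imagen3') {
    return randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX);
  }
  return shouldRandomizeSeed ? undefined : seed;
};
```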
@@ -66,12 +93,12 @@ export const prepareLinearUIBatch = (
data.push(secondBatchDatumList);
}
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
const extendedPrompts = getExtendedPrompts({ seedBehaviour, iterations, prompts, model });
// zipped batch of prompts
firstBatchDatumList.push({
node_path: posCond.id,
field_name: 'prompt',
node_path: positivePromptFieldIdentifier.nodeId,
field_name: positivePromptFieldIdentifier.fieldName,
items: extendedPrompts,
});
@@ -83,9 +110,9 @@ export const prepareLinearUIBatch = (
items: extendedPrompts,
});
if (shouldConcatPrompts && model?.base === 'sdxl') {
if (shouldConcatPrompts && model.base === 'sdxl') {
firstBatchDatumList.push({
node_path: posCond.id,
node_path: positivePromptFieldIdentifier.nodeId,
field_name: 'style',
items: extendedPrompts,
});

Some files were not shown because too many files have changed in this diff.