Compare commits

...

133 Commits

Author SHA1 Message Date
psychedelicious
143487a492 chore: bump version to v5.11.0 2025-05-13 14:04:45 +10:00
psychedelicious
203fa04295 feat(nodes): support bottleneck flag for nodes 2025-05-13 11:56:40 +10:00
Mary Hipp Rogers
954fce3c67 feat(ui): custom error toast support (#8001)
* support for custom error toast components, starting with usage limit

* add support for all usage limits

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2025-05-08 15:53:10 -04:00
Mary Hipp
821889148a easier way to override Whats New 2025-05-07 15:40:21 -04:00
Mary Hipp
4c248d8c2c refetch queue list on mount 2025-05-07 15:37:55 -04:00
Mary Hipp
deb75805d4 use the max for iterations passed in 2025-05-06 18:26:40 -04:00
Mary Hipp Rogers
93110654da Change feature to disable apiModels to chatGPT4oModels only (#7996)
* display credit column in queue list if shouldShowCredits is true

* change apiModels feature to chatGPT4oModels feature

* empty

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2025-05-06 14:37:03 -04:00
psychedelicious
ff0c48d532 chore(ui): prettier 2025-05-06 09:07:52 -04:00
psychedelicious
de18073814 feat(ui): support imagen3/chatgpt-4o models in canvas 2025-05-06 09:07:52 -04:00
psychedelicious
0708af9545 feat(ui): support imagen3/chatgpt-4o models in workflow editor 2025-05-06 09:07:52 -04:00
psychedelicious
1e85184c62 feat(nodes): add imagen3/chatgpt-4o field types 2025-05-06 09:07:52 -04:00
psychedelicious
11d3b8d944 feat(ui): add usage info to model picker 2025-05-06 09:07:52 -04:00
psychedelicious
bffd4afb96 chore(ui): typegen 2025-05-06 09:07:52 -04:00
psychedelicious
518a896521 feat(mm): add usage_info to model config 2025-05-06 09:07:52 -04:00
psychedelicious
2647ff141a feat(ui): add basic metadata to imagen3/chatgpt-4o graphs 2025-05-06 09:07:52 -04:00
Mary Hipp Rogers
ba0bac2aa5 add credits to queue item status changed (#7993)
* display credit column in queue list if shouldShowCredits is true

* add credits when queue item status changes

* chore(ui): typegen

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-05-06 08:54:44 -04:00
psychedelicious
862e2a3e49 chore(ui): typegen 2025-05-05 16:09:13 -04:00
Mary Hipp
d22fd32b05 typegen 2025-05-05 16:09:13 -04:00
Mary Hipp
391e5b7f8c update schema 2025-05-05 16:09:13 -04:00
Mary Hipp
c9d2a5f59a display credit column in queue list if shouldShowCredits is true 2025-05-05 16:09:13 -04:00
Kent Keirsey
1f63b60021 Implementing support for Non-Standard LoRA Format (#7985)
* integrate loRA

* idk anymore tbh

* enable fused matrix for quantized models

* integrate loRA

* idk anymore tbh

* enable fused matrix for quantized models

* ruff fix

---------

Co-authored-by: Sam <bhaskarmdutt@gmail.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-05-05 09:40:38 -04:00
psychedelicious
a499b9f54e chore: bump version to v5.11.0rc2 2025-05-05 23:32:27 +10:00
psychedelicious
104505ea02 chore(ui): lint 2025-05-05 23:25:29 +10:00
psychedelicious
ee4002607c feat(ui): add UI to reset hf token 2025-05-05 23:25:29 +10:00
psychedelicious
fd20582cdd chore(ui): typegen 2025-05-05 23:25:29 +10:00
psychedelicious
43b0d07517 feat(api): add route to reset hf token 2025-05-05 23:25:29 +10:00
blessedcoolant
f83592a052 fix: deprecation warning in get_iso_timestemp 2025-05-05 11:45:30 +10:00
Mary Hipp
b3ee906749 add prompt validation to imagen3 graph 2025-05-01 13:02:13 -04:00
psychedelicious
5d69e9068a feat(ui): add ability to globally disable hotkeys
This will both hide the hotkey from the hotkey modal and override any other enabled status it has.
2025-05-01 10:50:34 -04:00
psychedelicious
a79136b058 fix(ui): always add selectModelsTab hotkey data to prevent unhandled exception while registering the hotkey handler 2025-05-01 10:50:34 -04:00
psychedelicious
944af4d4a9 feat(ui): show unsupported gen mode toasts as warnings intead of errors 2025-05-01 23:25:01 +10:00
psychedelicious
5e001be73a tidy(ui): remove excessive nav to mm buttons 2025-05-01 23:22:19 +10:00
psychedelicious
576a644b3a tidy(ui): modelpicker component 2025-05-01 23:22:19 +10:00
psychedelicious
703557c8a6 feat(ui): cleanup 2025-05-01 23:22:19 +10:00
psychedelicious
d59a53b3f9 feat(ui): simplify picker types 2025-05-01 23:22:19 +10:00
psychedelicious
7b8f78c2d9 fix(ui): focus bug w/ popvoer 2025-05-01 23:22:19 +10:00
psychedelicious
31ab9be79a feat(ui): iterate on picker 2025-05-01 23:22:19 +10:00
psychedelicious
5011fab85d fix(ui): restore FLUX Dev info popover to main model picker 2025-05-01 10:59:51 +10:00
psychedelicious
92bdb9fdcc chore(ui): remove unused exports 2025-05-01 10:59:51 +10:00
Mary Hipp
548e766c0b feat(ui): ability to disable generating with API models 2025-05-01 10:59:51 +10:00
Mary Hipp
ff897f74a1 send the list of reference images reversed to chatGPT so it matches displayed order 2025-04-30 15:56:38 -04:00
psychedelicious
3d29c996ed feat(ui): support img2img for chatgpt 4o w/ ref images 2025-04-30 13:39:05 +10:00
psychedelicious
42d57d1225 fix(ui): ref image layout 2025-04-30 13:39:05 +10:00
psychedelicious
193fa9395a fix(ui): match ref image model to main model when creating global ref image 2025-04-30 13:39:05 +10:00
psychedelicious
56cd839d5b feat(ui): support for ref images for chatgpt on canvas 2025-04-30 13:39:05 +10:00
ubansi
7b446ee40d docs: fix Contribute node import error
When I followed the Contribute Node documentation, I encountered an import error.
This commit fixes the error, which will help reduce debugging time for all future contributors.
2025-04-29 21:03:00 -04:00
Mary Hipp Rogers
17027c4070 Maryhipp/chatgpt UI (#7969)
* add GPTimage1 as allowed base model

* fix for non-disabled inpaint layers

* lots of boilerplate for adding gpt-image base model and disabling things along with imagen

* handle gpt-image dimensions

* build graph for gpt-image

* lint

* feat(ui): make chatgpt model naming consistent

* feat(ui): graph builder naming

* feat(ui): disable img2img for imagen3

* feat(ui): more naming

* feat(ui): support presigned url prefetch

* feat(ui): disable neg prompt for chatgpt

* docs(ui): update docstring

* feat(ui): fix graph building issues for chatgpt

* fix(ui): node ids for chatgpt/imagen

* chore(ui): typegen

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-04-29 09:38:03 -04:00
psychedelicious
13d44f47ce chore(ui): prettier 2025-04-29 09:12:49 +10:00
psychedelicious
550fbdeb1c fix(ui): more types fixes 2025-04-29 09:12:49 +10:00
psychedelicious
a01cd7c497 fix(ui): add chatgpt-4o to zod schemas that need to match autogenerated types 2025-04-29 09:12:49 +10:00
Mary Hipp
c54afd600c typegen 2025-04-29 09:12:49 +10:00
Mary Hipp
4f911a0ea8 typegen 2025-04-29 09:12:49 +10:00
Mary Hipp
fb91f48722 change base model for chatGPT 4o 2025-04-29 09:12:49 +10:00
psychedelicious
69db60a614 fix(ui): toast typo 2025-04-29 06:56:36 +10:00
Mary Hipp
c6d7f951aa typegen 2025-04-28 15:39:11 -04:00
Mary Hipp
04c005284c add gpt-image to possible base model types 2025-04-28 15:39:11 -04:00
psychedelicious
2d7f9697bf chore(ui): lint 2025-04-28 13:31:26 -04:00
psychedelicious
ae530492a2 chore(ui): typegen 2025-04-28 13:31:26 -04:00
psychedelicious
87ed1e3b6d feat(ui): do not allow imagen3 nodes in published workflows 2025-04-28 13:31:26 -04:00
psychedelicious
cc54466db9 fix(nodes): default value for UIConfigBase.tags 2025-04-28 13:31:26 -04:00
psychedelicious
cbdafe7e38 feat(nodes): allow node clobbering 2025-04-28 13:31:26 -04:00
psychedelicious
112cb76174 fix: random seed for edit mode imagen 2025-04-28 13:31:26 -04:00
psychedelicious
e56d41ab99 feat: rip out enhance prompt as toggleable option, imagen always randomizes seed 2025-04-28 13:31:26 -04:00
psychedelicious
273dfd86ab fix(ui): upscale builder 2025-04-28 13:31:26 -04:00
psychedelicious
871271fde5 feat(ui): rough out imagen3 support for canvas 2025-04-28 13:31:26 -04:00
psychedelicious
14944872c4 feat(mm): add model taxonomy for API models & Imagen3 as base model type 2025-04-28 13:31:26 -04:00
psychedelicious
07bcf3c446 feat(ui): port bbox select to native select 2025-04-28 13:31:26 -04:00
psychedelicious
8ed5585285 feat(nodes): move output metadata to BaseInvocationOutput 2025-04-28 09:19:43 -04:00
psychedelicious
5ce226a467 chore(ui): typegen 2025-04-28 09:19:43 -04:00
Mary Hipp
c64f20a72b remove output_metdata from schema 2025-04-28 09:19:43 -04:00
Mary Hipp
0c9c10a03a update schema 2025-04-28 09:19:43 -04:00
Mary Hipp
4a0df6b865 add optional output_metadata to baseinvocation 2025-04-28 09:19:43 -04:00
psychedelicious
ba165572bf chore: bump version to v5.11.0rc1 2025-04-28 10:10:50 +10:00
psychedelicious
c3d6a10603 fix(ui): handle minor breaking typing change from serialize-error 2025-04-28 09:53:08 +10:00
psychedelicious
4efc86299d fix(ui): type error in SettingsUpsellMenuItem 2025-04-28 09:53:08 +10:00
psychedelicious
e8c7cf63fd fix(ui): type error in canvas worker 2025-04-28 09:53:08 +10:00
psychedelicious
698b034190 chore(ui): bump deps 2025-04-28 09:53:08 +10:00
psychedelicious
3988128c40 feat(ui): add _all_ image outputs to gallery (including collections) 2025-04-28 09:49:04 +10:00
psychedelicious
c768f47365 fix(ui): dnd autoscroll in scrollable containers 2025-04-28 09:46:38 +10:00
psychedelicious
19a63abc54 fix(ui): hide file size on model picker when it is zero 2025-04-23 17:45:09 +10:00
psychedelicious
75ec36bf9a chore(ui): lint 2025-04-23 17:45:09 +10:00
psychedelicious
d802f8e7fb feat(ui): disable search when no options 2025-04-23 17:45:09 +10:00
psychedelicious
6873e0308d feat(ui): custom fallback for model picker when no models installed 2025-04-23 17:45:09 +10:00
psychedelicious
66eb73088e feat(ui): rename user-provided extra ctx for picker from ctx to extra to be less confusing 2025-04-23 17:45:09 +10:00
psychedelicious
ed81a13eb4 docs(ui): add some comments for picker 2025-04-23 17:45:09 +10:00
psychedelicious
fbc1aae52d feat(ui): more flexible fallbacks for model picker 2025-04-23 17:45:09 +10:00
psychedelicious
ba42c3e63f feat(ui): tooltip for compact/full model picker view 2025-04-23 17:45:09 +10:00
psychedelicious
b24e820aa0 fix(ui): flash of "select a model" when changing model 2025-04-23 17:45:09 +10:00
psychedelicious
e8f6b3b77a feat(ui): split out mainmodelpicker component 2025-04-23 17:45:09 +10:00
psychedelicious
8f13518c97 feat(ui): add clear search button to model combobox 2025-04-23 17:45:09 +10:00
psychedelicious
6afbc12074 feat(ui): when no model bases selected, show all models 2025-04-23 17:45:09 +10:00
psychedelicious
6b0a56ceb9 chore(ui): lint 2025-04-23 17:45:09 +10:00
psychedelicious
ca92497e52 feat(ui): remove description from model pciker for now 2025-04-23 17:45:09 +10:00
psychedelicious
97d45ceaf2 feat(ui): model picker filter buttons 2025-04-23 17:45:09 +10:00
psychedelicious
aeb3841a6f feat(ui): wip model picker 2025-04-23 17:45:09 +10:00
psychedelicious
c14d33d3c1 tweak(ui): remove bg on ModelImage fallback 2025-04-23 17:45:09 +10:00
psychedelicious
676e59e072 chore(ui): bump react-resizable-panels to latest
This resolves a bug where SVG elements were ignored when checking when cursor is over a resize handle
2025-04-23 17:45:09 +10:00
psychedelicious
e7dcb6a03f feat(ui): wip model picker 2025-04-23 17:45:09 +10:00
psychedelicious
fb95b7cc2b feat(ui): wip model picker 2025-04-23 17:45:09 +10:00
psychedelicious
015dc3ac0d feat(ui): wip model picker 2025-04-23 17:45:09 +10:00
psychedelicious
9d8a71b362 feat(ui): genericizing picker 2025-04-23 17:45:09 +10:00
psychedelicious
2eb212f393 feat(ui): onSelectId -> onSelectById 2025-04-23 17:45:09 +10:00
psychedelicious
34b268c15c feat(ui): use context for stable picker state 2025-04-23 17:45:09 +10:00
psychedelicious
9a203a64dc feat(ui): render picker in portal 2025-04-23 17:45:09 +10:00
psychedelicious
d80004e056 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
de32ed23a7 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
5aed2b315d feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
48db6cfc4f feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
aa7c5c281a feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
87aeb7f889 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
3b3d6e413a feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
b6432f2de3 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
9d0a28ccae feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
c3bf0a3277 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
b516610c1e feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
677e717cd7 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
c52584e057 feat(ui): simplify ScrollableContent 2025-04-23 17:45:09 +10:00
psychedelicious
b6767441db feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
8745dbe67d feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
a565d9473e feat(ui): add useStateImperative 2025-04-23 17:45:09 +10:00
psychedelicious
4dbf07c3e0 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
f6eb4d9a6b feat(ui): toast on select for demo purposes 2025-04-23 17:45:09 +10:00
psychedelicious
5037967b82 feat(ui): just make the damn thing myself 2025-04-23 17:45:09 +10:00
psychedelicious
4930ba48ce feat(ui): just make the damn thing myself 2025-04-23 17:45:09 +10:00
psychedelicious
40d2092256 feat(ui): reworked model selection ui (WIP) 2025-04-23 17:45:09 +10:00
psychedelicious
d2e9237740 feat(ui): reworked model selection ui (WIP) 2025-04-23 17:45:09 +10:00
psychedelicious
b191b706c1 feat(ui): reworked model selection ui (WIP) 2025-04-23 17:45:09 +10:00
psychedelicious
4d0f760ec8 chore(ui): bump cmdk to latest 2025-04-23 17:45:09 +10:00
psychedelicious
65cda5365a feat(ui): remove go to mm button from node fields 2025-04-23 17:45:09 +10:00
psychedelicious
1f2d1d086f feat(ui): add <NavigateToModelManagerButton /> to model comboboxes everywhere 2025-04-23 17:45:09 +10:00
psychedelicious
418f3c3f19 feat(ui): abstract out workflow editor model combobox, ensure consistent ui for all model fields 2025-04-23 17:45:09 +10:00
psychedelicious
72173e284c fix(ui): useModelCombobox should use null for no value instead of undefined
This fixes an issue where the refiner combobox doesn't clear itself visually when clicking the little X icon to clear the selection.
2025-04-23 17:45:09 +10:00
psychedelicious
9cc13556aa feat(ui): accept callback to override navigate to model manager functionality
If provided, `<NavigateToModelManagerButton />` will render, even if `disabledTabs` includes "models". If provided, `<NavigateToModelManagerButton />` will run the callback instead of switching tabs within the studio.

The button's tooltip is now just "Manage Models" and its icon is the same as the model manager tab's icon ([CUBE!](https://www.youtube.com/watch?v=4aGDCE6Nrz0)).
2025-04-23 17:45:09 +10:00
162 changed files with 5445 additions and 3420 deletions

View File

@@ -39,7 +39,7 @@ nodes imported in the `__init__.py` file are loaded. See the README in the nodes
folder for more examples:
```py
from .cool_node import CoolInvocation
from .cool_node import ResizeInvocation
```
## Creating A New Invocation
@@ -69,7 +69,10 @@ The first set of things we need to do when creating a new Invocation are -
So let us do that.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.invocation_api import (
BaseInvocation,
invocation,
)
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -103,8 +106,12 @@ create your own custom field types later in this guide. For now, let's go ahead
and use it.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
from invokeai.app.invocations.primitives import ImageField
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
invocation,
)
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -128,8 +135,12 @@ image: ImageField = InputField(description="The input image")
Great. Now let us create our other inputs for `width` and `height`
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
from invokeai.app.invocations.primitives import ImageField
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
invocation,
)
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -163,8 +174,13 @@ that are provided by it by InvokeAI.
Let us create this function first.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -191,8 +207,14 @@ all the necessary info related to image outputs. So let us use that.
We will cover how to create your own output types later in this guide.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
from invokeai.app.invocations.image import ImageOutput
@invocation('resize')
@@ -217,9 +239,15 @@ Perfect. Now that we have our Invocation setup, let us do what we want to do.
So let's do that.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.image import ImageOutput, ResourceOrigin, ImageCategory
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
from invokeai.app.invocations.image import ImageOutput
@invocation("resize")
class ResizeInvocation(BaseInvocation):

View File

@@ -893,6 +893,12 @@ class HFTokenHelper:
huggingface_hub.login(token=token, add_to_git_credential=False)
return cls.get_status()
@classmethod
def reset_token(cls) -> HFTokenStatus:
with SuppressOutput(), contextlib.suppress(Exception):
huggingface_hub.logout()
return cls.get_status()
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
async def get_hf_login_status() -> HFTokenStatus:
@@ -915,3 +921,8 @@ async def do_hf_login(
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
return token_status
@model_manager_router.delete("/hf_login", operation_id="reset_hf_token", response_model=HFTokenStatus)
async def reset_hf_token() -> HFTokenStatus:
return HFTokenHelper.reset_token()

View File

@@ -25,7 +25,7 @@ from typing import (
)
import semver
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
@@ -72,13 +72,24 @@ class Classification(str, Enum, metaclass=MetaEnum):
Special = "special"
class Bottleneck(str, Enum, metaclass=MetaEnum):
"""
The bottleneck of an invocation.
- `Network`: The invocation's execution is network-bound.
- `GPU`: The invocation's execution is GPU-bound.
"""
Network = "network"
GPU = "gpu"
class UIConfigBase(BaseModel):
"""
Provides additional node configuration to the UI.
This is used internally by the @invocation decorator logic. Do not use this directly.
"""
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
tags: Optional[list[str]] = Field(default=None, description="The node's tags")
title: Optional[str] = Field(default=None, description="The node's display name")
category: Optional[str] = Field(default=None, description="The node's category")
version: str = Field(
@@ -100,6 +111,12 @@ class BaseInvocationOutput(BaseModel):
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
"""
output_meta: Optional[dict[str, JsonValue]] = Field(
default=None,
description="Optional dictionary of metadata for the invocation output, unrelated to the invocation's actual output value. This is not exposed as an output field.",
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
@@ -235,6 +252,8 @@ class BaseInvocation(ABC, BaseModel):
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
bottleneck: ClassVar[Bottleneck]
UIConfig: ClassVar[UIConfigBase]
model_config = ConfigDict(
@@ -256,6 +275,26 @@ class InvocationRegistry:
@classmethod
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
"""Registers an invocation."""
invocation_type = invocation.get_type()
node_pack = invocation.UIConfig.node_pack
# Log a warning when an existing invocation is being clobbered by the one we are registering
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
if clobbered_invocation is not None:
# This should always be true - we just checked if the invocation type was in the set
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
if clobbered_node_pack == "invokeai":
# The invocation being clobbered is a core invocation
logger.warning(f'Overriding core node "{invocation_type}" with node from "{node_pack}"')
else:
# The invocation being clobbered is a custom invocation
logger.warning(
f'Overriding node "{invocation_type}" from "{node_pack}" with node from "{clobbered_node_pack}"'
)
cls._invocation_classes.remove(clobbered_invocation)
cls._invocation_classes.add(invocation)
cls.invalidate_invocation_typeadapter()
@@ -314,6 +353,15 @@ class InvocationRegistry:
@classmethod
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
"""Registers an invocation output."""
output_type = output.get_type()
# Log a warning when an existing invocation is being clobbered by the one we are registering
clobbered_output = InvocationRegistry.get_output_for_type(output_type)
if clobbered_output is not None:
# TODO(psyche): We do not record the node pack of the output, so we cannot log it here
logger.warning(f'Overriding invocation output "{output_type}"')
cls._output_classes.remove(clobbered_output)
cls._output_classes.add(output)
cls.invalidate_output_typeadapter()
@@ -322,6 +370,11 @@ class InvocationRegistry:
"""Gets all invocation outputs."""
return cls._output_classes
@classmethod
def get_outputs_map(cls) -> dict[str, type[BaseInvocationOutput]]:
"""Gets a map of all output types to their output classes."""
return {i.get_type(): i for i in cls.get_output_classes()}
@classmethod
@lru_cache(maxsize=1)
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
@@ -347,6 +400,11 @@ class InvocationRegistry:
"""Gets all invocation output types."""
return (i.get_type() for i in cls.get_output_classes())
@classmethod
def get_output_for_type(cls, output_type: str) -> type[BaseInvocationOutput] | None:
"""Gets the output class for a given output type."""
return cls.get_outputs_map().get(output_type)
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"id",
@@ -354,11 +412,12 @@ RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"use_cache",
"type",
"workflow",
"bottleneck",
}
RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"}
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
RESERVED_OUTPUT_FIELD_NAMES = {"type", "output_meta"}
class _Model(BaseModel):
@@ -438,6 +497,7 @@ def invocation(
version: Optional[str] = None,
use_cache: Optional[bool] = True,
classification: Classification = Classification.Stable,
bottleneck: Bottleneck = Bottleneck.GPU,
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
"""
Registers an invocation.
@@ -449,6 +509,7 @@ def invocation(
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
:param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
:param Bottleneck bottleneck: The bottleneck of the invocation. Defaults to Bottleneck.GPU. Use Network if the invocation is network-bound.
"""
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
@@ -460,25 +521,6 @@ def invocation(
# The node pack is the module name - will be "invokeai" for built-in nodes
node_pack = cls.__module__.split(".")[0]
# Handle the case where an existing node is being clobbered by the one we are registering
if invocation_type in InvocationRegistry.get_invocation_types():
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
# This should always be true - we just checked if the invocation type was in the set
assert clobbered_invocation is not None
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
if clobbered_node_pack == "invokeai":
# The node being clobbered is a core node
raise ValueError(
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a core node with the same type already exists'
)
else:
# The node being clobbered is a custom node
raise ValueError(
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a node with the same type already exists in node pack "{clobbered_node_pack}"'
)
validate_fields(cls.model_fields, invocation_type)
# Add OpenAPI schema extras
@@ -504,6 +546,8 @@ def invocation(
if use_cache is not None:
cls.model_fields["use_cache"].default = use_cache
cls.bottleneck = bottleneck
# Add the invocation type to the model.
# You'd be tempted to just add the type field and rebuild the model, like this:
@@ -572,13 +616,9 @@ def invocation_output(
if re.compile(r"^\S+$").match(output_type) is None:
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
if output_type in InvocationRegistry.get_output_types():
raise ValueError(f'Invocation type "{output_type}" already exists')
validate_fields(cls.model_fields, output_type)
# Add the output type to the model.
output_type_annotation = Literal[output_type] # type: ignore
output_type_field = Field(
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}

View File

@@ -61,6 +61,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
SigLipModel = "SigLipModelField"
FluxReduxModel = "FluxReduxModelField"
LlavaOnevisionModel = "LLaVAModelField"
Imagen3Model = "Imagen3ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
# endregion
# region Misc Field Types

View File

@@ -241,6 +241,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
@classmethod
def build(
@@ -263,6 +264,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
credits=queue_item.credits,
)

View File

@@ -257,6 +257,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
api_output_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The nodes that were used as output from the API"
)
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
@classmethod
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":

View File

@@ -61,6 +61,10 @@ def get_openapi_func(
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in InvocationRegistry.get_output_classes():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
# Remove output_metadata that is only used on back-end from the schema
if "output_meta" in json_schema["properties"]:
json_schema["properties"].pop("output_meta")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema

View File

@@ -10,7 +10,7 @@ def get_timestamp() -> int:
def get_iso_timestamp() -> str:
return datetime.datetime.utcnow().isoformat()
return datetime.datetime.now(datetime.timezone.utc).isoformat()
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:

View File

@@ -144,6 +144,7 @@ class ModelConfigBase(ABC, BaseModel):
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
description="Loadable submodels in this model", default=None
)
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
_USING_LEGACY_PROBE: ClassVar[set] = set()
_USING_CLASSIFY_API: ClassVar[set] = set()
@@ -600,6 +601,21 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
}
class ApiModelConfig(MainConfigBase, ModelConfigBase):
"""Model config for API-based models."""
format: Literal[ModelFormat.Api] = ModelFormat.Api
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
# API models are not stored on disk, so we can't match them.
return False
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
raise NotImplementedError("API models are not parsed from disk.")
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
@@ -667,6 +683,7 @@ AnyModelConfig = Annotated[
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
Annotated[ApiModelConfig, ApiModelConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]

View File

@@ -13,6 +13,12 @@ from invokeai.backend.patches.layers.lora_layer import LoRALayer
def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
# up matrix and down matrix have different ranks so we can't simply multiply them
if lora_layer.up.shape[1] != lora_layer.down.shape[0]:
x = torch.nn.functional.linear(input, lora_layer.get_weight(lora_weight), bias=lora_layer.bias)
x *= lora_weight * lora_layer.scale()
return x
x = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x = torch.nn.functional.linear(x, lora_layer.mid)

View File

@@ -26,7 +26,8 @@ class BaseModelType(str, Enum):
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
CogView4 = "cogview4"
# Kandinsky2_1 = "kandinsky-2.1"
Imagen3 = "imagen3"
ChatGPT4o = "chatgpt-4o"
class ModelType(str, Enum):
@@ -98,6 +99,7 @@ class ModelFormat(str, Enum):
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
GGUFQuantized = "gguf_quantized"
Api = "api"
class SchedulerPredictionType(str, Enum):

View File

@@ -19,6 +19,7 @@ class LoRALayer(LoRALayerBase):
self.up = up
self.mid = mid
self.down = down
self.are_ranks_equal = up.shape[1] == down.shape[0]
@classmethod
def from_state_dict_values(
@@ -58,12 +59,42 @@ class LoRALayer(LoRALayerBase):
def _rank(self) -> int:
return self.down.shape[0]
def fuse_weights(self, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
"""
Fuse the weights of the up and down matrices of a LoRA layer with different ranks.
Since the Huggingface implementation of KQV projections are fused, when we convert to Kohya format
the LoRA weights have different ranks. This function handles the fusion of these differently sized
matrices.
"""
fused_lora = torch.zeros((up.shape[0], down.shape[1]), device=down.device, dtype=down.dtype)
rank_diff = down.shape[0] / up.shape[1]
if rank_diff > 1:
rank_diff = down.shape[0] / up.shape[1]
w_down = down.chunk(int(rank_diff), dim=0)
for w_down_chunk in w_down:
fused_lora = fused_lora + (torch.mm(up, w_down_chunk))
else:
rank_diff = up.shape[1] / down.shape[0]
w_up = up.chunk(int(rank_diff), dim=0)
for w_up_chunk in w_up:
fused_lora = fused_lora + (torch.mm(w_up_chunk, down))
return fused_lora
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
# up matrix and down matrix have different ranks so we can't simply multiply them
if not self.are_ranks_equal:
weight = self.fuse_weights(self.up, self.down)
return weight
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight

View File

@@ -20,6 +20,14 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
# A regex pattern that matches all of the last layer keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_unet_final_layer_linear.alpha
# lora_unet_final_layer_linear.lora_down.weight
# lora_unet_final_layer_linear.lora_up.weight
FLUX_KOHYA_LAST_LAYER_KEY_REGEX = r"lora_unet_final_layer_(linear|linear1|linear2)_?(.*)"
# A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha
@@ -44,6 +52,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
"""
return all(
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
or re.match(FLUX_KOHYA_LAST_LAYER_KEY_REGEX, k)
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
@@ -65,6 +74,9 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
for layer_name, layer_state_dict in grouped_state_dict.items():
if layer_name.startswith("lora_unet"):
# Skip the final layer. This is incompatible with current model definition.
if layer_name.startswith("lora_unet_final_layer"):
continue
transformer_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te1"):
clip_grouped_sd[layer_name] = layer_state_dict

View File

@@ -52,68 +52,68 @@
}
},
"dependencies": {
"@atlaskit/pragmatic-drag-and-drop": "^1.4.0",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^1.4.0",
"@atlaskit/pragmatic-drag-and-drop": "^1.5.3",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^2.1.0",
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.0.3",
"@dagrejs/dagre": "^1.1.4",
"@dagrejs/graphlib": "^2.2.4",
"@fontsource-variable/inter": "^5.1.0",
"@fontsource-variable/inter": "^5.2.5",
"@invoke-ai/ui-library": "^0.0.46",
"@nanostores/react": "^0.7.3",
"@reduxjs/toolkit": "2.6.1",
"@nanostores/react": "^1.0.0",
"@reduxjs/toolkit": "2.7.0",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.5.3",
"@xyflow/react": "^12.6.0",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.2",
"cmdk": "^1.0.0",
"cmdk": "^1.1.1",
"compare-versions": "^6.1.1",
"filesize": "^10.1.6",
"fracturedjsonjs": "^4.0.2",
"framer-motion": "^11.10.0",
"i18next": "^23.15.1",
"i18next-http-backend": "^2.6.1",
"i18next": "^25.0.1",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "^6.2.1",
"jsondiffpatch": "^0.6.0",
"konva": "^9.3.15",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.20",
"linkify-react": "^4.2.0",
"linkifyjs": "^4.2.0",
"lodash-es": "^4.17.21",
"lru-cache": "^11.0.1",
"lru-cache": "^11.1.0",
"mtwist": "^1.0.2",
"nanoid": "^5.0.7",
"nanostores": "^0.11.3",
"new-github-issue-url": "^1.0.0",
"overlayscrollbars": "^2.10.0",
"nanoid": "^5.1.5",
"nanostores": "^1.0.1",
"new-github-issue-url": "^1.1.0",
"overlayscrollbars": "^2.11.1",
"overlayscrollbars-react": "^0.5.6",
"perfect-freehand": "^1.2.2",
"query-string": "^9.1.0",
"query-string": "^9.1.1",
"raf-throttle": "^2.0.6",
"react": "^18.3.1",
"react-colorful": "^5.6.1",
"react-dom": "^18.3.1",
"react-dropzone": "^14.2.9",
"react-error-boundary": "^4.0.13",
"react-hook-form": "^7.53.0",
"react-dropzone": "^14.3.8",
"react-error-boundary": "^5.0.0",
"react-hook-form": "^7.56.1",
"react-hotkeys-hook": "4.5.0",
"react-i18next": "^15.0.2",
"react-icons": "^5.3.0",
"react-redux": "9.1.2",
"react-resizable-panels": "^2.1.4",
"react-textarea-autosize": "^8.5.7",
"react-use": "^17.5.1",
"react-virtuoso": "^4.12.5",
"react-i18next": "^15.5.1",
"react-icons": "^5.5.0",
"react-redux": "9.2.0",
"react-resizable-panels": "^2.1.8",
"react-textarea-autosize": "^8.5.9",
"react-use": "^17.6.0",
"react-virtuoso": "^4.12.6",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.1.0",
"redux-remember": "^5.2.0",
"redux-undo": "^1.1.0",
"rfdc": "^1.4.1",
"roarr": "^7.21.1",
"serialize-error": "^11.0.3",
"socket.io-client": "^4.8.0",
"stable-hash": "^0.0.4",
"use-debounce": "^10.0.3",
"serialize-error": "^12.0.0",
"socket.io-client": "^4.8.1",
"stable-hash": "^0.0.5",
"use-debounce": "^10.0.4",
"use-device-pixel-ratio": "^1.1.2",
"uuid": "^10.0.0",
"zod": "^3.23.8",
"uuid": "^11.1.0",
"zod": "^3.24.3",
"zod-validation-error": "^3.4.0"
},
"peerDependencies": {
@@ -123,43 +123,43 @@
"devDependencies": {
"@invoke-ai/eslint-config-react": "^0.0.14",
"@invoke-ai/prettier-config-react": "^0.0.7",
"@storybook/addon-essentials": "^8.3.4",
"@storybook/addon-interactions": "^8.3.4",
"@storybook/addon-links": "^8.3.4",
"@storybook/addon-storysource": "^8.3.4",
"@storybook/manager-api": "^8.3.4",
"@storybook/react": "^8.3.4",
"@storybook/react-vite": "^8.5.5",
"@storybook/theming": "^8.3.4",
"@storybook/addon-essentials": "^8.6.12",
"@storybook/addon-interactions": "^8.6.12",
"@storybook/addon-links": "^8.6.12",
"@storybook/addon-storysource": "^8.6.12",
"@storybook/manager-api": "^8.6.12",
"@storybook/react": "^8.6.12",
"@storybook/react-vite": "^8.6.12",
"@storybook/theming": "^8.6.12",
"@types/lodash-es": "^4.17.12",
"@types/node": "^20.16.10",
"@types/node": "^22.15.1",
"@types/react": "^18.3.11",
"@types/react-dom": "^18.3.0",
"@types/uuid": "^10.0.0",
"@vitejs/plugin-react-swc": "^3.8.0",
"@vitest/coverage-v8": "^3.0.6",
"@vitest/ui": "^3.0.6",
"concurrently": "^8.2.2",
"@vitejs/plugin-react-swc": "^3.9.0",
"@vitest/coverage-v8": "^3.1.2",
"@vitest/ui": "^3.1.2",
"concurrently": "^9.1.2",
"csstype": "^3.1.3",
"dpdm": "^3.14.0",
"eslint": "^8.57.1",
"eslint-plugin-i18next": "^6.1.0",
"eslint-plugin-i18next": "^6.1.1",
"eslint-plugin-path": "^1.3.0",
"knip": "^5.31.0",
"knip": "^5.50.5",
"openapi-types": "^12.1.3",
"openapi-typescript": "^7.4.1",
"prettier": "^3.3.3",
"rollup-plugin-visualizer": "^5.12.0",
"storybook": "^8.3.4",
"openapi-typescript": "^7.6.1",
"prettier": "^3.5.3",
"rollup-plugin-visualizer": "^5.14.0",
"storybook": "^8.6.12",
"tsafe": "^1.8.5",
"type-fest": "^4.26.1",
"typescript": "^5.6.2",
"vite": "^6.1.0",
"type-fest": "^4.40.0",
"typescript": "^5.8.3",
"vite": "^6.3.3",
"vite-plugin-css-injected-by-js": "^3.5.2",
"vite-plugin-dts": "^4.5.0",
"vite-plugin-dts": "^4.5.3",
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^5.1.4",
"vitest": "^3.0.6"
"vitest": "^3.1.2"
},
"engines": {
"pnpm": "8"

File diff suppressed because it is too large Load Diff

View File

@@ -118,6 +118,8 @@
"error": "Error",
"error_withCount_one": "{{count}} error",
"error_withCount_other": "{{count}} errors",
"model_withCount_one": "{{count}} model",
"model_withCount_other": "{{count}} models",
"file": "File",
"folder": "Folder",
"format": "format",
@@ -138,6 +140,8 @@
"localSystem": "Local System",
"learnMore": "Learn More",
"modelManager": "Model Manager",
"noMatches": "No matches",
"noOptions": "No options",
"nodes": "Workflows",
"notInstalled": "Not $t(common.installed)",
"openInNewTab": "Open in New Tab",
@@ -171,6 +175,8 @@
"blue": "Blue",
"alpha": "Alpha",
"selected": "Selected",
"search": "Search",
"clear": "Clear",
"tab": "Tab",
"view": "View",
"edit": "Edit",
@@ -197,7 +203,11 @@
"column": "Column",
"value": "Value",
"label": "Label",
"systemInformation": "System Information"
"systemInformation": "System Information",
"compactView": "Compact View",
"fullView": "Full View",
"options_withCount_one": "{{count}} option",
"options_withCount_other": "{{count}} options"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -258,6 +268,7 @@
"status": "Status",
"total": "Total",
"time": "Time",
"credits": "Credits",
"pending": "Pending",
"in_progress": "In Progress",
"completed": "Completed",
@@ -768,6 +779,7 @@
"description": "Description",
"edit": "Edit",
"fileSize": "File Size",
"filterModels": "Filter models",
"fluxRedux": "FLUX Redux",
"height": "Height",
"huggingFace": "HuggingFace",
@@ -787,6 +799,7 @@
"hfTokenUnableToVerify": "Unable to Verify HF Token",
"hfTokenUnableToVerifyErrorMessage": "Unable to verify HuggingFace token. This is likely due to a network error. Please try again later.",
"hfTokenSaved": "HF Token Saved",
"hfTokenReset": "HF Token Reset",
"urlUnauthorizedErrorMessage": "You may need to configure an API token to access this model.",
"urlUnauthorizedErrorMessage2": "Learn how here.",
"imageEncoderModelId": "Image Encoder Model ID",
@@ -821,10 +834,12 @@
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"name": "Name",
"noModelsInstalled": "No Models Installed",
"modelPickerFallbackNoModelsInstalled": "No models installed.",
"modelPickerFallbackNoModelsInstalled2": "Visit the <LinkComponent>Model Manager</LinkComponent> to install models.",
"noModelsInstalledDesc1": "Install models with the",
"noModelSelected": "No Model Selected",
"noMatchingModels": "No matching Models",
"noMatchingModels": "No matching models",
"noModelsInstalled": "No models installed",
"none": "none",
"path": "Path",
"pathToConfig": "Path To Config",
@@ -871,7 +886,8 @@
"installingXModels_one": "Installing {{count}} model",
"installingXModels_other": "Installing {{count}} models",
"skippingXDuplicates_one": ", skipping {{count}} duplicate",
"skippingXDuplicates_other": ", skipping {{count}} duplicates"
"skippingXDuplicates_other": ", skipping {{count}} duplicates",
"manageModels": "Manage Models"
},
"models": {
"addLora": "Add LoRA",
@@ -1093,6 +1109,7 @@
"info": "Info",
"invoke": {
"addingImagesTo": "Adding images to",
"modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your account settings to upgrade.",
"invoke": "Invoke",
"missingFieldTemplate": "Missing field template",
"missingInputForField": "missing input",
@@ -1173,7 +1190,8 @@
"width": "Width",
"gaussianBlur": "Gaussian Blur",
"boxBlur": "Box Blur",
"staged": "Staged"
"staged": "Staged",
"modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your <LinkComponent>account settings</LinkComponent> to upgrade."
},
"dynamicPrompts": {
"showDynamicPrompts": "Show Dynamic Prompts",
@@ -1312,6 +1330,8 @@
"unableToCopyDesc": "Your browser does not support clipboard access. Firefox users may be able to fix this by following ",
"unableToCopyDesc_theseSteps": "these steps",
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"imagen3IncompatibleGenerationMode": "Google Imagen3 supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
"workflowUnpublished": "Workflow Unpublished"

View File

@@ -5,6 +5,7 @@ import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $customNavComponent } from 'app/store/nanostores/customNavComponent';
@@ -12,10 +13,13 @@ import type { CustomStarUi } from 'app/store/nanostores/customStarUI';
import { $customStarUI } from 'app/store/nanostores/customStarUI';
import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { $logo } from 'app/store/nanostores/logo';
import { $onClickGoToModelManager } from 'app/store/nanostores/onClickGoToModelManager';
import { $openAPISchemaUrl } from 'app/store/nanostores/openAPISchemaUrl';
import { $projectId, $projectName, $projectUrl } from 'app/store/nanostores/projectId';
import { $queueId, DEFAULT_QUEUE_ID } from 'app/store/nanostores/queueId';
import { $store } from 'app/store/nanostores/store';
import { $toastMap } from 'app/store/nanostores/toastMap';
import { $whatsNew } from 'app/store/nanostores/whatsNew';
import { createStore } from 'app/store/store';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
@@ -29,6 +33,7 @@ import {
DEFAULT_WORKFLOW_LIBRARY_TAG_CATEGORIES,
} from 'features/nodes/store/workflowLibrarySlice';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { ToastConfig } from 'features/toast/toast';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
@@ -45,6 +50,7 @@ interface Props extends PropsWithChildren {
token?: string;
config?: PartialAppConfig;
customNavComponent?: ReactNode;
accountSettingsLink?: string;
middleware?: Middleware[];
projectId?: string;
projectName?: string;
@@ -55,10 +61,16 @@ interface Props extends PropsWithChildren {
socketOptions?: Partial<ManagerOptions & SocketOptions>;
isDebugging?: boolean;
logo?: ReactNode;
toastMap?: Record<string, ToastConfig>;
whatsNew?: ReactNode[];
workflowCategories?: WorkflowCategory[];
workflowTagCategories?: WorkflowTagCategory[];
workflowSortOptions?: WorkflowSortOption[];
loggingOverrides?: LoggingOverrides;
/**
* If provided, overrides in-app navigation to the model manager
*/
onClickGoToModelManager?: () => void;
}
const InvokeAIUI = ({
@@ -67,6 +79,7 @@ const InvokeAIUI = ({
token,
config,
customNavComponent,
accountSettingsLink,
middleware,
projectId,
projectName,
@@ -77,10 +90,13 @@ const InvokeAIUI = ({
socketOptions,
isDebugging = false,
logo,
toastMap,
workflowCategories,
workflowTagCategories,
workflowSortOptions,
loggingOverrides,
onClickGoToModelManager,
whatsNew,
}: Props) => {
useLayoutEffect(() => {
/*
@@ -169,6 +185,16 @@ const InvokeAIUI = ({
};
}, [customNavComponent]);
useEffect(() => {
if (accountSettingsLink) {
$accountSettingsLink.set(accountSettingsLink);
}
return () => {
$accountSettingsLink.set(undefined);
};
}, [accountSettingsLink]);
useEffect(() => {
if (openAPISchemaUrl) {
$openAPISchemaUrl.set(openAPISchemaUrl);
@@ -205,6 +231,36 @@ const InvokeAIUI = ({
};
}, [logo]);
useEffect(() => {
if (toastMap) {
$toastMap.set(toastMap);
}
return () => {
$toastMap.set(undefined);
};
}, [toastMap]);
useEffect(() => {
if (whatsNew) {
$whatsNew.set(whatsNew);
}
return () => {
$whatsNew.set(undefined);
};
}, [whatsNew]);
useEffect(() => {
if (onClickGoToModelManager) {
$onClickGoToModelManager.set(onClickGoToModelManager);
}
return () => {
$onClickGoToModelManager.set(undefined);
};
}, [onClickGoToModelManager]);
useEffect(() => {
if (workflowCategories) {
$workflowLibraryCategoriesOptions.set(workflowCategories);

View File

@@ -1,3 +1,4 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
@@ -6,11 +7,14 @@ import { withResult, withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
@@ -48,32 +52,50 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
return await buildFLUXGraph(state, manager);
case 'cogview4':
return await buildCogView4Graph(state, manager);
case 'imagen3':
return await buildImagen3Graph(state, manager);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, manager);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status: 'error',
title: 'Failed to build graph',
status,
title,
description,
});
return;
}
const { g, noise, posCond } = buildGraphResult.value;
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'canvas', destination)
prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'canvas',
destination,
})
);
if (prepareBatchResult.isErr()) {
@@ -89,7 +111,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
await req.unwrap();
log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
} finally {
req.reset();
}

View File

@@ -18,16 +18,24 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
const state = getState();
const { prepend } = action.payload;
const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond, 'upscaling', 'gallery');
const batchConfig = prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'upscaling',
destination: 'gallery',
});
const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
try {
await req.unwrap();
log.debug(parseify({ batchConfig }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
} finally {
req.reset();
}

View File

@@ -0,0 +1,3 @@
import { atom } from 'nanostores';
export const $accountSettingsLink = atom<string | undefined>(undefined);

View File

@@ -0,0 +1,3 @@
import { atom } from 'nanostores';
export const $onClickGoToModelManager = atom<(() => void) | undefined>(undefined);

View File

@@ -0,0 +1,4 @@
import type { ToastConfig } from 'features/toast/toast';
import { atom } from 'nanostores';
export const $toastMap = atom<Record<string, ToastConfig> | undefined>(undefined);

View File

@@ -0,0 +1,4 @@
import { atom } from 'nanostores';
import type { ReactNode } from 'react';
export const $whatsNew = atom<ReactNode[] | undefined>(undefined);

View File

@@ -145,7 +145,10 @@ const unserialize: UnserializeFunction = (data, key) => {
);
return transformed;
} catch (err) {
log.warn({ error: serializeError(err) }, `Error rehydrating slice "${key}", falling back to default initial state`);
log.warn(
{ error: serializeError(err as Error) },
`Error rehydrating slice "${key}", falling back to default initial state`
);
return persistConfig.initialState;
}
};

View File

@@ -28,7 +28,8 @@ export type AppFeature =
| 'starterModels'
| 'hfToken'
| 'retryQueueItem'
| 'cancelAndClearAll';
| 'cancelAndClearAll'
| 'chatGPT4oModels';
/**
* A disable-able Stable Diffusion feature
*/
@@ -83,6 +84,7 @@ export type AppConfig = {
metadataFetchDebounce?: number;
workflowFetchDebounce?: number;
isLocal?: boolean;
shouldShowCredits: boolean;
sd: {
defaultModel?: string;
disabledControlNetModels: string[];

File diff suppressed because it is too large Load Diff

View File

@@ -38,7 +38,7 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
}, [optionsFilter, getIsDisabled, modelConfigs, shouldShowModelDescriptions]);
const value = useMemo(
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)) ?? null,
[options, selectedModel]
);

View File

@@ -1,6 +1,10 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { memo } from 'react';
/**
* A typed version of React.memo, useful for components that take generics.
*/
export const typedMemo: <T>(c: T) => T = memo;
export const typedMemo: <T extends keyof JSX.IntrinsicElements | React.JSXElementConstructor<any>>(
component: T,
propsAreEqual?: (prevProps: React.ComponentProps<T>, nextProps: React.ComponentProps<T>) => boolean
) => T & { displayName?: string } = memo;

View File

@@ -24,6 +24,7 @@ export const CanvasAddEntityButtons = memo(() => {
const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');
return (
<Flex w="full" h="full" justifyContent="center" gap={4}>
@@ -52,6 +53,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addInpaintMask}
isDisabled={!isInpaintLayerEnabled}
>
{t('controlLayers.inpaintMask')}
</Button>

View File

@@ -25,6 +25,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');
return (
<Menu>
@@ -46,7 +47,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.regional')}>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask} isDisabled={!isInpaintLayerEnabled}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={!isRegionalGuidanceEnabled}>

View File

@@ -0,0 +1,63 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ApiModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
type Props = {
modelKey: string | null;
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => void;
};
export const GlobalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useGlobalReferenceImageModels();
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
const _onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null) => {
if (!modelConfig) {
return;
}
onChangeModel(modelConfig);
},
[onChangeModel]
);
const getIsDisabled = useCallback(
(model: AnyModelConfig): boolean => {
const hasMainModel = Boolean(currentBaseModel);
const hasSameBase = currentBaseModel === model.base;
return !hasMainModel || !hasSameBase;
},
[currentBaseModel]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChangeModel,
selectedModel,
getIsDisabled,
isLoading,
});
return (
<Tooltip label={selectedModel?.description}>
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
);
});
GlobalReferenceImageModel.displayName = 'GlobalReferenceImageModel';

View File

@@ -61,7 +61,7 @@ export const IPAdapterImagePreview = memo(
)}
{imageDTO && (
<>
<DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" />
<DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" w="full" />
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={handleResetControlImage}

View File

@@ -6,6 +6,7 @@ import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/c
import { Weight } from 'features/controlLayers/components/common/Weight';
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
import { GlobalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/GlobalReferenceImageModel';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -33,10 +34,9 @@ import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
import { IPAdapterModel } from './IPAdapterModel';
const buildSelectIPAdapter = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) =>
createSelector(
@@ -80,7 +80,7 @@ const IPAdapterSettingsContent = memo(() => {
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => {
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
},
[dispatch, entityIdentifier]
@@ -113,11 +113,7 @@ const IPAdapterSettingsContent = memo(() => {
<CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<IPAdapterModel
isRegionalGuidance={false}
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
/>
<GlobalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}

View File

@@ -4,29 +4,26 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useIPAdapterOrFLUXReduxModels } from 'services/api/hooks/modelsByType';
import { useRegionalReferenceImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
type Props = {
isRegionalGuidance: boolean;
modelKey: string | null;
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => void;
};
export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeModel }: Props) => {
const filter = (config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
if (config.base === 'flux' && config.type === 'ip_adapter') {
return false;
}
return true;
};
export const RegionalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const filter = useCallback(
(config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
if (isRegionalGuidance && config.base === 'flux' && config.type === 'ip_adapter') {
return false;
}
return true;
},
[isRegionalGuidance]
);
const [modelConfigs, { isLoading }] = useIPAdapterOrFLUXReduxModels(filter);
const [modelConfigs, { isLoading }] = useRegionalReferenceImageModels(filter);
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
const _onChangeModel = useCallback(
@@ -71,4 +68,4 @@ export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeMode
);
});
IPAdapterModel.displayName = 'IPAdapterModel';
RegionalReferenceImageModel.displayName = 'RegionalReferenceImageModel';

View File

@@ -7,7 +7,7 @@ import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLI
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
import { RegionalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/RegionalReferenceImageModel';
import { RegionalGuidanceIPAdapterSettingsEmptyState } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoRegionalGuidanceReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
@@ -140,11 +140,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
</Flex>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<IPAdapterModel
isRegionalGuidance={true}
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
/>
<RegionalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}

View File

@@ -17,16 +17,26 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
ControlLoRAConfig,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/util';
import {
initialChatGPT4oReferenceImage,
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback } from 'react';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import {
modelConfigsAdapterSelectors,
selectMainModelConfig,
selectModelConfigsQuery,
} from 'services/api/endpoints/models';
import type {
ControlLoRAModelConfig,
ControlNetModelConfig,
@@ -64,6 +74,35 @@ export const selectDefaultControlAdapter = createSelector(
}
);
export const selectDefaultRefImageConfig = createSelector(
selectMainModelConfig,
selectModelConfigsQuery,
selectBase,
(selectedMainModel, query, base): CanvasReferenceImageState['ipAdapter'] => {
if (selectedMainModel?.base === 'chatgpt-4o') {
const referenceImage = deepClone(initialChatGPT4oReferenceImage);
referenceImage.model = zModelIdentifierField.parse(selectedMainModel);
return referenceImage;
}
const { data } = query;
let model: IPAdapterModelConfig | null = null;
if (data) {
const modelConfigs = modelConfigsAdapterSelectors.selectAll(data).filter(isIPAdapterModelConfig);
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
model = compatibleModels[0] ?? modelConfigs[0] ?? null;
}
const ipAdapter = deepClone(initialIPAdapter);
if (model) {
ipAdapter.model = zModelIdentifierField.parse(model);
if (model.base === 'flux') {
ipAdapter.clipVisionModel = 'ViT-L';
}
}
return ipAdapter;
}
);
/**
* Selects the default IP adapter configuration based on the model configurations and the base.
*
@@ -146,11 +185,11 @@ export const useAddRegionalReferenceImage = () => {
export const useAddGlobalReferenceImage = () => {
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const defaultRefImage = useAppSelector(selectDefaultRefImageConfig);
const func = useCallback(() => {
const overrides = { ipAdapter: deepClone(defaultIPAdapter) };
const overrides = { ipAdapter: deepClone(defaultRefImage) };
dispatch(referenceImageAdded({ isSelected: true, overrides }));
}, [defaultIPAdapter, dispatch]);
}, [defaultRefImage, dispatch]);
return func;
};

View File

@@ -41,7 +41,7 @@ export const useCopyLayerToClipboard = () => {
});
});
} catch (error) {
log.error({ error: serializeError(error) }, 'Problem copying layer to clipboard');
log.error({ error: serializeError(error as Error) }, 'Problem copying layer to clipboard');
toast({
status: 'error',
title: t('toast.problemCopyingLayer'),
@@ -82,7 +82,7 @@ export const useCopyCanvasToClipboard = (region: 'canvas' | 'bbox') => {
toast({ title: t('controlLayers.regionCopiedToClipboard', { region: startCase(region) }) });
});
} catch (error) {
log.error({ error: serializeError(error) }, 'Failed to save canvas to gallery');
log.error({ error: serializeError(error as Error) }, 'Failed to save canvas to gallery');
toast({ title: t('controlLayers.copyRegionError', { region: startCase(region) }), status: 'error' });
}
}, [canvasManager.compositor, canvasManager.stateApi, clipboard, region, t]);

View File

@@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHook
import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { selectDefaultIPAdapter, selectDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
controlLayerAdded,
@@ -198,7 +198,7 @@ export const useNewRegionalReferenceImageFromBbox = () => {
export const useNewGlobalReferenceImageFromBbox = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const defaultIPAdapter = useAppSelector(selectDefaultRefImageConfig);
const arg = useMemo<UseSaveCanvasArg>(() => {
const onSave = (imageDTO: ImageDTO) => {

View File

@@ -1,5 +1,10 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsCogView4, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsImagen3,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
import type { Equals } from 'tsafe';
@@ -8,23 +13,25 @@ import { assert } from 'tsafe';
export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isSD3 = useAppSelector(selectIsSD3);
const isCogView4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'reference_image':
return !isSD3 && !isCogView4;
return !isSD3 && !isCogView4 && !isImagen3;
case 'regional_guidance':
return !isSD3 && !isCogView4;
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
case 'control_layer':
return !isSD3 && !isCogView4;
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
case 'inpaint_mask':
return true;
return !isImagen3 && !isChatGPT4o;
case 'raster_layer':
return true;
return !isImagen3 && !isChatGPT4o;
default:
assert<Equals<typeof entityType, never>>(false);
}
}, [entityType, isSD3, isCogView4]);
}, [entityType, isSD3, isCogView4, isImagen3, isChatGPT4o]);
return isEntityTypeEnabled;
};

View File

@@ -41,7 +41,7 @@ export const useSaveLayerToAssets = () => {
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
});
} catch (error) {
log.error({ error: serializeError(error) }, 'Problem copying layer to clipboard');
log.error({ error: serializeError(error as Error) }, 'Problem copying layer to clipboard');
toast({
status: 'error',
title: t('toast.problemSavingLayer'),

View File

@@ -519,7 +519,7 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
this.manager.cache.imageNameCache.set(hash, imageDTO.image_name);
return imageDTO;
} catch (error) {
this.log.error({ rasterizeArgs, error: serializeError(error) }, 'Failed to rasterize entity');
this.log.error({ rasterizeArgs, error: serializeError(error as Error) }, 'Failed to rasterize entity');
throw error;
} finally {
this.manager.stateApi.$rasterizingAdapter.set(null);

View File

@@ -346,7 +346,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
// If the user is not holding shift, the transform is retaining aspect ratio. It's not possible to snap to the grid
// in this case, because that would change the aspect ratio. So, we only snap to the grid when shift is held.
const gridSize = this.manager.stateApi.$shiftKey.get() ? this.manager.stateApi.getGridSize() : 1;
const gridSize = this.manager.stateApi.$shiftKey.get() ? this.manager.stateApi.getPositionGridSize() : 1;
// We need to snap the anchor to the selected grid size, but the positions provided to this callback are absolute,
// scaled coordinates. They need to be converted to stage coordinates, snapped, then converted back to absolute
@@ -464,7 +464,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
return;
}
const { rect } = this.manager.stateApi.getBbox();
const gridSize = this.manager.stateApi.getGridSize();
const gridSize = this.manager.stateApi.getPositionGridSize();
const width = this.konva.proxyRect.width();
const height = this.konva.proxyRect.height();
const scaleX = rect.width / width;
@@ -498,7 +498,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
return;
}
const { rect } = this.manager.stateApi.getBbox();
const gridSize = this.manager.stateApi.getGridSize();
const gridSize = this.manager.stateApi.getPositionGridSize();
const width = this.konva.proxyRect.width();
const height = this.konva.proxyRect.height();
const scaleX = rect.width / width;
@@ -523,7 +523,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
onDragMove = () => {
// Snap the interaction rect to the grid
const gridSize = this.manager.stateApi.getGridSize();
const gridSize = this.manager.stateApi.getPositionGridSize();
this.konva.proxyRect.x(roundToMultiple(this.konva.proxyRect.x(), gridSize));
this.konva.proxyRect.y(roundToMultiple(this.konva.proxyRect.y(), gridSize));

View File

@@ -112,7 +112,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
return;
}
const imageElementResult = await withResultAsync(() => loadImage(imageDTO.image_url));
const imageElementResult = await withResultAsync(() => loadImage(imageDTO.image_url, true));
if (imageElementResult.isErr()) {
// Image loading failed (e.g. the URL to the "physical" image is invalid)
this.onFailedToLoadImage(t('controlLayers.unableToLoadImage', 'Unable to load image'));

View File

@@ -493,7 +493,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
* Gets the _positional_ grid size for the current canvas. Note that this is not the same as bbox grid size, which is
* based on the currently-selected model.
*/
getGridSize = (): number => {
getPositionGridSize = (): number => {
const snapToGrid = this.getSettings().snapToGrid;
if (!snapToGrid) {
return 1;

View File

@@ -4,8 +4,10 @@ import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { fitRectToGrid, getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectBbox } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect } from 'features/controlLayers/store/types';
import type { Coordinate, Rect, Tool } from 'features/controlLayers/store/types';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import Konva from 'konva';
import { noop } from 'lodash-es';
import { atom } from 'nanostores';
@@ -178,6 +180,9 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
// Listen for the bbox overlay setting to update the overlay's visibility
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectBboxOverlay, this.render));
// Listen for the model changing - some model types constraint the bbox to a certain size or aspect ratio.
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectModel, this.render));
// Update on busy state changes
this.subscriptions.add(this.manager.$isBusy.listen(this.render));
}
@@ -218,12 +223,25 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
this.syncOverlay();
const model = this.manager.stateApi.runSelector(selectModel);
this.konva.transformer.setAttrs({
listening: tool === 'bbox',
enabledAnchors: tool === 'bbox' ? ALL_ANCHORS : NO_ANCHORS,
enabledAnchors: this.getEnabledAnchors(tool, model),
});
};
getEnabledAnchors = (tool: Tool, model?: ModelIdentifierField | null): string[] => {
if (tool !== 'bbox') {
return NO_ANCHORS;
}
if (model?.base === 'imagen3' || model?.base === 'chatgpt-4o') {
// The bbox is not resizable in these modes
return NO_ANCHORS;
}
return ALL_ANCHORS;
};
syncOverlay = () => {
const bboxOverlay = this.manager.stateApi.getSettings().bboxOverlay;
@@ -251,7 +269,7 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
onDragMove = () => {
// The grid size here is the _position_ grid size, not the _dimension_ grid size - it is not constratined by the
// currently-selected model.
const gridSize = this.manager.stateApi.getGridSize();
const gridSize = this.manager.stateApi.getPositionGridSize();
const bbox = this.manager.stateApi.getBbox();
const bboxRect: Rect = {
...bbox.rect,

View File

@@ -476,15 +476,24 @@ export function getImageDataTransparency(imageData: ImageData): Transparency {
/**
* Loads an image from a URL and returns a promise that resolves with the loaded image element.
* @param src The image source URL
* @param fetchUrlFirst Whether to fetch the image's URL first, assuming the provided `src` will redirect to a different URL. This addresses an issue where CORS headers are dropped during a redirect.
* @returns A promise that resolves with the loaded image element
*/
export function loadImage(src: string): Promise<HTMLImageElement> {
export async function loadImage(src: string, fetchUrlFirst?: boolean): Promise<HTMLImageElement> {
const authToken = $authToken.get();
let url = src;
if (authToken && fetchUrlFirst) {
const response = await fetch(`${src}?url_only=true`, { credentials: 'include' });
const data = await response.json();
url = data.url;
}
return new Promise((resolve, reject) => {
const imageElement = new Image();
imageElement.onload = () => resolve(imageElement);
imageElement.onerror = (error) => reject(error);
imageElement.crossOrigin = $authToken.get() ? 'use-credentials' : 'anonymous';
imageElement.src = src;
imageElement.src = url;
});
}

View File

@@ -10,12 +10,12 @@ export type Extents = {
/**
* Get the bounding box of an image.
* @param buffer The ArrayBuffer of the image to get the bounding box of.
* @param buffer The ArrayBufferLike of the image to get the bounding box of.
* @param width The width of the image.
* @param height The height of the image.
* @returns The minimum and maximum x and y values of the image's bounding box, or null if the image has no pixels.
*/
const getImageDataBboxArrayBuffer = (buffer: ArrayBuffer, width: number, height: number): Extents | null => {
const getImageDataBboxArrayBufferLike = (buffer: ArrayBufferLike, width: number, height: number): Extents | null => {
let minX = width;
let minY = height;
let maxX = -1;
@@ -50,7 +50,7 @@ const getImageDataBboxArrayBuffer = (buffer: ArrayBuffer, width: number, height:
export type GetBboxTask = {
type: 'get_bbox';
data: { id: string; buffer: ArrayBuffer; width: number; height: number };
data: { id: string; buffer: ArrayBufferLike; width: number; height: number };
};
type TaskWithTimestamps<T extends Record<string, unknown>> = T & { started: number | null; finished: number | null };
@@ -95,7 +95,7 @@ function processNextTask() {
// Process the task
if (task.type === 'get_bbox') {
const { buffer, width, height, id } = task.data;
const extents = getImageDataBboxArrayBuffer(buffer, width, height);
const extents = getImageDataBboxArrayBufferLike(buffer, width, height);
const result: ExtentsResult = {
type: 'extents',
data: { id, extents },

View File

@@ -34,9 +34,10 @@ import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/com
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect } from 'konva/lib/types';
import { merge } from 'lodash-es';
import { isEqual, merge } from 'lodash-es';
import type { UndoableOptions } from 'redux-undo';
import type {
ApiModelConfig,
ControlLoRAModelConfig,
ControlNetModelConfig,
FLUXReduxModelConfig,
@@ -67,7 +68,7 @@ import type {
IPMethodV2,
T2IAdapterConfig,
} from './types';
import { getEntityIdentifier, isRenderableEntity } from './types';
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagen3AspectRatioID, isRenderableEntity } from './types';
import {
converters,
getControlLayerState,
@@ -76,6 +77,7 @@ import {
getReferenceImageState,
getRegionalGuidanceState,
imageDTOToImageWithDims,
initialChatGPT4oReferenceImage,
initialControlLoRA,
initialControlNet,
initialFLUXRedux,
@@ -644,7 +646,10 @@ export const canvasSlice = createSlice({
referenceImageIPAdapterModelChanged: (
state,
action: PayloadAction<
EntityIdentifierPayload<{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null }, 'reference_image'>
EntityIdentifierPayload<
{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null },
'reference_image'
>
>
) => {
const { entityIdentifier, modelConfig } = action.payload;
@@ -652,14 +657,36 @@ export const canvasSlice = createSlice({
if (!entity) {
return;
}
const oldModel = entity.ipAdapter.model;
// First set the new model
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
if (!entity.ipAdapter.model) {
return;
}
if (entity.ipAdapter.type === 'ip_adapter' && entity.ipAdapter.model.type === 'flux_redux') {
// Switching from ip_adapter to flux_redux
if (isEqual(oldModel, entity.ipAdapter.model)) {
// Nothing changed, so we don't need to do anything
return;
}
// The type of ref image depends on the model. When the user switches the model, we rebuild the ref image.
// When we switch the model, we keep the image the same, but change the other parameters.
if (entity.ipAdapter.model.base === 'chatgpt-4o') {
// Switching to chatgpt-4o ref image
entity.ipAdapter = {
...initialChatGPT4oReferenceImage,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.model.type === 'flux_redux') {
// Switching to flux_redux
entity.ipAdapter = {
...initialFLUXRedux,
image: entity.ipAdapter.image,
@@ -668,17 +695,13 @@ export const canvasSlice = createSlice({
return;
}
if (entity.ipAdapter.type === 'flux_redux' && entity.ipAdapter.model.type === 'ip_adapter') {
// Switching from flux_redux to ip_adapter
if (entity.ipAdapter.model.type === 'ip_adapter') {
// Switching to ip_adapter
entity.ipAdapter = {
...initialIPAdapter,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.type === 'ip_adapter') {
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (entity.ipAdapter.model?.base === 'flux') {
entity.ipAdapter.clipVisionModel = 'ViT-L';
@@ -686,6 +709,7 @@ export const canvasSlice = createSlice({
// Fall back to ViT-H (ViT-G would also work)
entity.ipAdapter.clipVisionModel = 'ViT-H';
}
return;
}
},
referenceImageIPAdapterCLIPVisionModelChanged: (
@@ -1139,7 +1163,21 @@ export const canvasSlice = createSlice({
syncScaledSize(state);
},
bboxChangedFromCanvas: (state, action: PayloadAction<IRect>) => {
state.bbox.rect = action.payload;
const newBboxRect = action.payload;
const oldBboxRect = state.bbox.rect;
state.bbox.rect = newBboxRect;
if (newBboxRect.width === oldBboxRect.width && newBboxRect.height === oldBboxRect.height) {
return;
}
const oldAspectRatio = state.bbox.aspectRatio.value;
const newAspectRatio = newBboxRect.width / newBboxRect.height;
if (oldAspectRatio === newAspectRatio) {
return;
}
// TODO(psyche): Figure out a way to handle this without resetting the aspect ratio on every change.
// This action is dispatched when the user resizes or moves the bbox from the canvas. For now, when the user
@@ -1198,6 +1236,40 @@ export const canvasSlice = createSlice({
state.bbox.aspectRatio.id = id;
if (id === 'Free') {
state.bbox.aspectRatio.isLocked = false;
} else if (state.bbox.modelBase === 'imagen3' && isImagen3AspectRatioID(id)) {
// Imagen3 has specific output sizes that are not exactly the same as the aspect ratio. Need special handling.
if (id === '16:9') {
state.bbox.rect.width = 1408;
state.bbox.rect.height = 768;
} else if (id === '4:3') {
state.bbox.rect.width = 1280;
state.bbox.rect.height = 896;
} else if (id === '1:1') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1024;
} else if (id === '3:4') {
state.bbox.rect.width = 896;
state.bbox.rect.height = 1280;
} else if (id === '9:16') {
state.bbox.rect.width = 768;
state.bbox.rect.height = 1408;
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else if (state.bbox.modelBase === 'chatgpt-4o' && isChatGPT4oAspectRatioID(id)) {
// gpt-image has specific output sizes that are not exactly the same as the aspect ratio. Need special handling.
if (id === '3:2') {
state.bbox.rect.width = 1536;
state.bbox.rect.height = 1024;
} else if (id === '1:1') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1024;
} else if (id === '2:3') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1536;
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
@@ -1670,6 +1742,13 @@ export const canvasSlice = createSlice({
const base = model?.base;
if (isMainModelBase(base) && state.bbox.modelBase !== base) {
state.bbox.modelBase = base;
if (base === 'imagen3' || base === 'chatgpt-4o') {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = 1;
state.bbox.aspectRatio.id = '1:1';
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1024;
}
syncScaledSize(state);
}
});
@@ -1802,6 +1881,10 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
};
const syncScaledSize = (state: CanvasState) => {
if (state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'chatgpt-4o') {
// Imagen3 has fixed sizes. Scaled bbox is not supported.
return;
}
if (state.bbox.scaleMethod === 'auto') {
// Sync both aspect ratio and size
const { width, height } = state.bbox.rect;

View File

@@ -380,6 +380,8 @@ export const selectIsSDXL = createParamsSelector((params) => params.model?.base
export const selectIsFLUX = createParamsSelector((params) => params.model?.base === 'flux');
export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3');
export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4');
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
export const selectIsChatGTP4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
export const selectModel = createParamsSelector((params) => params.model);
export const selectModelKey = createParamsSelector((params) => params.model?.key);

View File

@@ -245,6 +245,18 @@ const zFLUXReduxConfig = z.object({
});
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
const zChatGPT4oReferenceImageConfig = z.object({
type: z.literal('chatgpt_4o_reference_image'),
image: zImageWithDims.nullable(),
/**
* TODO(psyche): Technically there is no model for ChatGPT 4o reference images - it's just a field in the API call.
* But we use a model drop down to switch between different ref image types, so there needs to be a model here else
* there will be no way to switch between ref image types.
*/
model: zServerValidatedModelIdentifierField.nullable(),
});
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
const zCanvasEntityBase = z.object({
id: zId,
name: zName,
@@ -254,15 +266,19 @@ const zCanvasEntityBase = z.object({
const zCanvasReferenceImageState = zCanvasEntityBase.extend({
type: z.literal('reference_image'),
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
// This should be named `referenceImage` but we need to keep it as `ipAdapter` for backwards compatibility
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig, zChatGPT4oReferenceImageConfig]),
});
export type CanvasReferenceImageState = z.infer<typeof zCanvasReferenceImageState>;
export const isIPAdapterConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is IPAdapterConfig =>
export const isIPAdapterConfig = (config: CanvasReferenceImageState['ipAdapter']): config is IPAdapterConfig =>
config.type === 'ip_adapter';
export const isFLUXReduxConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is FLUXReduxConfig =>
export const isFLUXReduxConfig = (config: CanvasReferenceImageState['ipAdapter']): config is FLUXReduxConfig =>
config.type === 'flux_redux';
export const isChatGPT4oReferenceImageConfig = (
config: CanvasReferenceImageState['ipAdapter']
): config is ChatGPT4oReferenceImageConfig => config.type === 'chatgpt_4o_reference_image';
const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
export type FillStyle = z.infer<typeof zFillStyle>;
@@ -387,9 +403,18 @@ export type StagingAreaImage = {
offsetY: number;
};
const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export const zImagen3AspectRatioID = z.enum(['16:9', '4:3', '1:1', '3:4', '9:16']);
export const isImagen3AspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
zImagen3AspectRatioID.safeParse(v).success;
export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);
export const isChatGPT4oAspectRatioID = (v: unknown): v is z.infer<typeof zChatGPT4oAspectRatioID> =>
zChatGPT4oAspectRatioID.safeParse(v).success;
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
export const isAspectRatioID = (v: string): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
const zCanvasState = z.object({
_version: z.literal(3),

View File

@@ -7,6 +7,7 @@ import type {
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
ChatGPT4oReferenceImageConfig,
ControlLoRAConfig,
ControlNetConfig,
FLUXReduxConfig,
@@ -77,6 +78,11 @@ export const initialFLUXRedux: FLUXReduxConfig = {
model: null,
imageInfluence: 'highest',
};
export const initialChatGPT4oReferenceImage: ChatGPT4oReferenceImageConfig = {
type: 'chatgpt_4o_reference_image',
image: null,
model: null,
};
export const initialT2IAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
model: null,

View File

@@ -4,7 +4,7 @@ import type { PersistConfig, RootState } from 'app/store/store';
import { z } from 'zod';
const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
export type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
export const isSeedBehaviour = (v: unknown): v is SeedBehaviour => zSeedBehaviour.safeParse(v).success;
export interface DynamicPromptsState {

View File

@@ -1,6 +1,6 @@
import type { AppDispatch, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { selectDefaultIPAdapter, selectDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasReset } from 'features/controlLayers/store/actions';
@@ -116,7 +116,7 @@ export const createNewCanvasEntityFromImage = (arg: {
break;
}
case 'reference_image': {
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
const ipAdapter = deepClone(selectDefaultRefImageConfig(getState()));
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));
break;
@@ -238,7 +238,7 @@ export const newCanvasFromImage = (arg: {
break;
}
case 'reference_image': {
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
const ipAdapter = deepClone(selectDefaultRefImageConfig(getState()));
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
dispatch(canvasReset());
dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));

View File

@@ -58,7 +58,7 @@ const LoRASelect = () => {
const noOptionsMessage = useCallback(() => t('models.noMatchingLoRAs'), [t]);
return (
<FormControl isDisabled={!options.length}>
<FormControl isDisabled={!options.length} gap={2}>
<InformationalPopover feature="lora">
<FormLabel>{t('models.concepts')} </FormLabel>
</InformationalPopover>

View File

@@ -12,63 +12,45 @@ import { skipToken } from '@reduxjs/toolkit/query';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { toast } from 'features/toast/toast';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetHFTokenStatusQuery, useSetHFTokenMutation } from 'services/api/endpoints/models';
import { UNAUTHORIZED_TOAST_ID } from 'services/events/onModelInstallError';
import {
useGetHFTokenStatusQuery,
useResetHFTokenMutation,
useSetHFTokenMutation,
} from 'services/api/endpoints/models';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
export const HFToken = () => {
const { t } = useTranslation();
const isHFTokenEnabled = useFeatureStatus('hfToken');
const [token, setToken] = useState('');
const { currentData } = useGetHFTokenStatusQuery(isHFTokenEnabled ? undefined : skipToken);
const [trigger, { isLoading, isUninitialized }] = useSetHFTokenMutation();
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setToken(e.target.value);
}, []);
const onClick = useCallback(() => {
trigger({ token })
.unwrap()
.then((res) => {
if (res === 'valid') {
setToken('');
toast({
id: UNAUTHORIZED_TOAST_ID,
title: t('modelManager.hfTokenSaved'),
status: 'success',
duration: 3000,
});
}
});
}, [t, token, trigger]);
const error = useMemo(() => {
if (!currentData || isUninitialized || isLoading) {
return null;
switch (currentData) {
case 'invalid':
return t('modelManager.hfTokenInvalidErrorMessage');
case 'unknown':
return t('modelManager.hfTokenUnableToVerifyErrorMessage');
case 'valid':
case undefined:
return null;
default:
assert<Equals<never, typeof currentData>>(false, 'Unexpected HF token status');
}
if (currentData === 'invalid') {
return t('modelManager.hfTokenInvalidErrorMessage');
}
if (currentData === 'unknown') {
return t('modelManager.hfTokenUnableToVerifyErrorMessage');
}
return null;
}, [currentData, isLoading, isUninitialized, t]);
}, [currentData, t]);
if (!currentData || currentData === 'valid') {
if (!currentData) {
return null;
}
return (
<Flex borderRadius="base" w="full">
<FormControl isInvalid={!isUninitialized && Boolean(error)} orientation="vertical">
<FormControl isInvalid={Boolean(error)} orientation="vertical">
<FormLabel>{t('modelManager.hfTokenLabel')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
<Input type="password" value={token} onChange={onChange} />
<Button onClick={onClick} size="sm" isDisabled={token.trim().length === 0} isLoading={isLoading}>
{t('common.save')}
</Button>
</Flex>
{error && <SetHFTokenInput />}
{!error && <ResetHFTokenButton />}
<FormHelperText>
<ExternalLink label={t('modelManager.hfTokenHelperText')} href="https://huggingface.co/settings/tokens" />
</FormHelperText>
@@ -77,3 +59,73 @@ export const HFToken = () => {
</Flex>
);
};
const PLACEHOLDER_TOKEN = Array.from({ length: 37 }, () => 'a').join('');
const ResetHFTokenButton = memo(() => {
const { t } = useTranslation();
const [resetHFToken, { isLoading }] = useResetHFTokenMutation();
const onClick = useCallback(() => {
resetHFToken()
.unwrap()
.then(() => {
toast({
title: t('modelManager.hfTokenReset'),
status: 'info',
});
});
}, [resetHFToken, t]);
return (
<Flex gap={3} alignItems="center" w="full">
<Input type="password" value={PLACEHOLDER_TOKEN} isDisabled />
<Button onClick={onClick} size="sm" isLoading={isLoading}>
{t('common.reset')}
</Button>
</Flex>
);
});
ResetHFTokenButton.displayName = 'ResetHFTokenButton';
const SetHFTokenInput = memo(() => {
const { t } = useTranslation();
const [token, setToken] = useState('');
const [trigger, { isLoading }] = useSetHFTokenMutation();
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setToken(e.target.value);
}, []);
const onClick = useCallback(() => {
trigger({ token })
.unwrap()
.then((res) => {
switch (res) {
case 'valid':
setToken('');
toast({
title: t('modelManager.hfTokenSaved'),
status: 'success',
});
break;
case 'invalid':
case 'unknown':
default:
toast({
title: t('modelManager.hfTokenUnableToVerify'),
status: 'error',
});
break;
}
});
}, [t, token, trigger]);
return (
<Flex gap={3} alignItems="center" w="full">
<Input type="password" value={token} onChange={onChange} />
<Button onClick={onClick} size="sm" isDisabled={token.trim().length === 0} isLoading={isLoading}>
{t('common.save')}
</Button>
</Flex>
);
});
SetHFTokenInput.displayName = 'SetHFTokenInput';

View File

@@ -1,11 +1,10 @@
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetHFTokenStatusQuery, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { HFToken } from './HFToken';
import { HuggingFaceResults } from './HuggingFaceResults';
@@ -16,7 +15,6 @@ export const HuggingFaceForm = memo(() => {
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();
const isHFTokenEnabled = useFeatureStatus('hfToken');
const { currentData } = useGetHFTokenStatusQuery(isHFTokenEnabled ? undefined : skipToken);
const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
const [installModel] = useInstallModel();
@@ -68,7 +66,7 @@ export const HuggingFaceForm = memo(() => {
<FormHelperText>{t('modelManager.huggingFaceHelper')}</FormHelperText>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</FormControl>
{currentData !== 'valid' && <HFToken />}
{isHFTokenEnabled && <HFToken />}
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
</Flex>
);

View File

@@ -7,7 +7,7 @@ type Props = {
base: BaseModelType;
};
const BASE_COLOR_MAP: Record<BaseModelType, string> = {
export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
any: 'base',
'sd-1': 'green',
'sd-2': 'teal',
@@ -15,12 +15,14 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
flux: 'gold',
cogview4: 'orange',
cogview4: 'red',
imagen3: 'pink',
'chatgpt-4o': 'pink',
};
const ModelBaseBadge = ({ base }: Props) => {
return (
<Badge flexGrow={0} colorScheme={BASE_COLOR_MAP[base]} variant="subtle" h="min-content">
<Badge flexGrow={0} flexShrink={0} colorScheme={BASE_COLOR_MAP[base]} variant="subtle" h="min-content">
{MODEL_TYPE_SHORT_MAP[base]}
</Badge>
);

View File

@@ -17,6 +17,7 @@ const FORMAT_NAME_MAP: Record<AnyModelConfig['format'], string> = {
bnb_quantized_int8b: 'bnb_quantized_int8b',
bnb_quantized_nf4b: 'quantized',
gguf_quantized: 'gguf',
api: 'api',
};
const FORMAT_COLOR_MAP: Record<AnyModelConfig['format'], string> = {
@@ -30,6 +31,7 @@ const FORMAT_COLOR_MAP: Record<AnyModelConfig['format'], string> = {
bnb_quantized_int8b: 'base',
bnb_quantized_nf4b: 'base',
gguf_quantized: 'base',
api: 'base',
};
const ModelFormatBadge = ({ format }: Props) => {

View File

@@ -15,7 +15,6 @@ const ModelImage = ({ image_url }: Props) => {
<Flex
height={MODEL_IMAGE_THUMBNAIL_SIZE}
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
bg="base.650"
borderRadius="base"
alignItems="center"
justifyContent="center"

View File

@@ -1,10 +1,12 @@
import { FloatFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInput';
import { FloatFieldInputAndSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInputAndSlider';
import { FloatFieldSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldSlider';
import ChatGPT4oModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ChatGPT4oModelFieldInputComponent';
import { FloatFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatFieldCollectionInputComponent';
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
import { IntegerFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerFieldCollectionInputComponent';
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
@@ -23,6 +25,8 @@ import {
isBoardFieldInputTemplate,
isBooleanFieldInputInstance,
isBooleanFieldInputTemplate,
isChatGPT4oModelFieldInputInstance,
isChatGPT4oModelFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isCLIPGEmbedModelFieldInputInstance,
@@ -57,6 +61,8 @@ import {
isImageFieldInputTemplate,
isImageGeneratorFieldInputInstance,
isImageGeneratorFieldInputTemplate,
isImagen3ModelFieldInputInstance,
isImagen3ModelFieldInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerFieldInputInstance,
@@ -394,6 +400,20 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <FluxReduxModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isImagen3ModelFieldInputTemplate(template)) {
if (!isImagen3ModelFieldInputInstance(field)) {
return null;
}
return <Imagen3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isChatGPT4oModelFieldInputTemplate(template)) {
if (!isChatGPT4oModelFieldInputInstance(field)) {
return null;
}
return <ChatGPT4oModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isColorFieldInputTemplate(template)) {
if (!isColorFieldInputInstance(field)) {
return null;

View File

@@ -1,11 +1,8 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPEmbedModelConfig } from 'services/api/types';
@@ -15,11 +12,9 @@ type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedMode
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: CLIPEmbedModelConfig | null) => {
if (!value) {
return;
@@ -34,32 +29,15 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,11 +1,8 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPGEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPGEmbedModelConfig, isCLIPGEmbedModelConfig } from 'services/api/types';
@@ -15,12 +12,10 @@ type Props = FieldComponentProps<CLIPGEmbedModelFieldInputInstance, CLIPGEmbedMo
const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: CLIPGEmbedModelConfig | null) => {
if (!value) {
return;
@@ -35,32 +30,15 @@ const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config))}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,11 +1,8 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPLEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPLEmbedModelConfig, isCLIPLEmbedModelConfig } from 'services/api/types';
@@ -15,12 +12,10 @@ type Props = FieldComponentProps<CLIPLEmbedModelFieldInputInstance, CLIPLEmbedMo
const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: CLIPLEmbedModelConfig | null) => {
if (!value) {
return;
@@ -35,32 +30,15 @@ const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config))}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldChatGPT4oModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useChatGPT4oModels } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const ChatGPT4oModelFieldInputComponent = (
props: FieldComponentProps<ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useChatGPT4oModels();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldChatGPT4oModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(ChatGPT4oModelFieldInputComponent);

View File

@@ -1,16 +1,13 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldControlLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type {
ControlLoRAModelFieldInputInstance,
ControlLoRAModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlLoRAModel } from 'services/api/hooks/modelsByType';
import { type ControlLoRAModelConfig, isControlLoRAModelConfig } from 'services/api/types';
import type { ControlLoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -18,12 +15,10 @@ type Props = FieldComponentProps<ControlLoRAModelFieldInputInstance, ControlLoRA
const ControlLoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useControlLoRAModel();
const _onChange = useCallback(
const onChange = useCallback(
(value: ControlLoRAModelConfig | null) => {
if (!value) {
return;
@@ -38,32 +33,15 @@ const ControlLoRAModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isControlLoRAModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useControlNetModels } from 'services/api/hooks/modelsByType';
@@ -17,7 +15,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useControlNetModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: ControlNetModelConfig | null) => {
if (!value) {
return;
@@ -33,25 +31,14 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxModels } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const FluxMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
@@ -31,25 +29,15 @@ const FluxMainModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldFluxReduxModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxReduxModels } from 'services/api/hooks/modelsByType';
@@ -18,7 +16,7 @@ const FluxReduxModelFieldInputComponent = (
const [modelConfigs, { isLoading }] = useFluxReduxModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: FLUXReduxModelConfig | null) => {
if (!value) {
return;
@@ -34,19 +32,14 @@ const FluxReduxModelFieldInputComponent = (
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,11 +1,8 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
@@ -15,11 +12,9 @@ type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFie
const FluxVAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: VAEModelConfig | null) => {
if (!value) {
return;
@@ -34,27 +29,15 @@ const FluxVAEModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
@@ -17,7 +15,7 @@ const IPAdapterModelFieldInputComponent = (
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useIPAdapterModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: IPAdapterModelConfig | null) => {
if (!value) {
return;
@@ -33,19 +31,14 @@ const IPAdapterModelFieldInputComponent = (
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldImagen3ModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useImagen3Models } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const Imagen3ModelFieldInputComponent = (
props: FieldComponentProps<Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useImagen3Models();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldImagen3ModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(Imagen3ModelFieldInputComponent);

View File

@@ -1,8 +1,6 @@
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldLLaVAModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useLLaVAModels } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const LLaVAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useLLaVAModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: LlavaOnevisionConfig | null) => {
if (!value) {
return;
@@ -32,23 +30,14 @@ const LLaVAModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value} isDisabled={!options.length}>
<Combobox
value={value}
placeholder={placeholder}
noOptionsMessage={noOptionsMessage}
options={options}
onChange={onChange}
/>
</FormControl>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useLoRAModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: LoRAModelConfig | null) => {
if (!value) {
return;
@@ -32,23 +30,14 @@ const LoRAModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value} isDisabled={!options.length}>
<Combobox
value={value}
placeholder={placeholder}
noOptionsMessage={noOptionsMessage}
options={options}
onChange={onChange}
/>
</FormControl>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
@@ -31,25 +29,15 @@ const MainModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -0,0 +1,51 @@
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { typedMemo } from 'common/util/typedMemo';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { AnyModelConfig } from 'services/api/types';
type Props<T extends AnyModelConfig> = {
value: ModelIdentifierField | undefined;
modelConfigs: T[];
isLoadingConfigs: boolean;
onChange: (value: T | null) => void;
required: boolean;
groupByType?: boolean;
};
const _ModelFieldCombobox = <T extends AnyModelConfig>({
value: _value,
modelConfigs,
isLoadingConfigs,
onChange: _onChange,
required,
groupByType,
}: Props<T>) => {
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading: isLoadingConfigs,
selectedModel: _value,
groupByType,
});
return (
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
gap={2}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
);
};
export const ModelFieldCombobox = typedMemo(_ModelFieldCombobox);

View File

@@ -1,9 +1,7 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useMemo } from 'react';
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
@@ -17,7 +15,7 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetModelConfigsQuery();
const _onChange = useCallback(
const onChange = useCallback(
(value: AnyModelConfig | null) => {
if (!value) {
return;
@@ -41,26 +39,15 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
return modelConfigsAdapterSelectors.selectAll(data);
}, [data]);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
groupByType: true,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
groupByType
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type {
SDXLRefinerModelFieldInputInstance,
SDXLRefinerModelFieldInputTemplate,
@@ -19,7 +17,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useRefinerModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
@@ -34,25 +32,15 @@ const RefinerModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSD3Models } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const SD3MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSD3Models();
const _onChange = useCallback(
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
@@ -31,29 +29,15 @@ const SD3MainModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && props.fieldTemplate.required}
>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSDXLModels } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSDXLModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
@@ -31,25 +29,15 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldSigLipModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSigLipModels } from 'services/api/hooks/modelsByType';
@@ -18,7 +16,7 @@ const SigLipModelFieldInputComponent = (
const [modelConfigs, { isLoading }] = useSigLipModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: SigLipModelConfig | null) => {
if (!value) {
return;
@@ -34,19 +32,14 @@ const SigLipModelFieldInputComponent = (
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldSpandrelImageToImageModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type {
SpandrelImageToImageModelFieldInputInstance,
SpandrelImageToImageModelFieldInputTemplate,
@@ -21,7 +19,7 @@ const SpandrelImageToImageModelFieldInputComponent = (
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: SpandrelImageToImageModelConfig | null) => {
if (!value) {
return;
@@ -37,19 +35,14 @@ const SpandrelImageToImageModelFieldInputComponent = (
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
@@ -18,7 +16,7 @@ const T2IAdapterModelFieldInputComponent = (
const [modelConfigs, { isLoading }] = useT2IAdapterModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: T2IAdapterModelConfig | null) => {
if (!value) {
return;
@@ -34,19 +32,14 @@ const T2IAdapterModelFieldInputComponent = (
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,12 +1,8 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldT5EncoderValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate } from 'features/nodes/types/field';
import { selectIsModelsTabDisabled } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
@@ -16,11 +12,9 @@ type Props = FieldComponentProps<T5EncoderModelFieldInputInstance, T5EncoderMode
const T5EncoderModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const isModelsTabDisabled = useAppSelector(selectIsModelsTabDisabled);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
if (!value) {
return;
@@ -35,31 +29,14 @@ const T5EncoderModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!isModelsTabDisabled && t('modelManager.starterModelsInModelManager')}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -1,8 +1,6 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useVAEModels } from 'services/api/hooks/modelsByType';
@@ -16,7 +14,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useVAEModels();
const _onChange = useCallback(
const onChange = useCallback(
(value: VAEModelConfig | null) => {
if (!value) {
return;
@@ -31,30 +29,15 @@ const VAEModelFieldInputComponent = (props: Props) => {
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && required}
>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

View File

@@ -121,6 +121,10 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
'metadata_to_controlnets',
'metadata_to_ip_adapters',
'metadata_to_t2i_adapters',
'google_imagen3_generate',
'google_imagen3_edit',
'chatgpt_create_image',
'chatgpt_edit_image',
];
export const selectHasUnpublishableNodes = createSelector(selectNodes, (nodes) => {

View File

@@ -23,6 +23,7 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
BoardFieldValue,
BooleanFieldValue,
ChatGPT4oModelFieldValue,
CLIPEmbedModelFieldValue,
CLIPGEmbedModelFieldValue,
CLIPLEmbedModelFieldValue,
@@ -38,6 +39,7 @@ import type {
ImageFieldCollectionValue,
ImageFieldValue,
ImageGeneratorFieldValue,
Imagen3ModelFieldValue,
IntegerFieldCollectionValue,
IntegerFieldValue,
IntegerGeneratorFieldValue,
@@ -61,6 +63,7 @@ import type {
import {
zBoardFieldValue,
zBooleanFieldValue,
zChatGPT4oModelFieldValue,
zCLIPEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
@@ -76,6 +79,7 @@ import {
zImageFieldCollectionValue,
zImageFieldValue,
zImageGeneratorFieldValue,
zImagen3ModelFieldValue,
zIntegerFieldCollectionValue,
zIntegerFieldValue,
zIntegerGeneratorFieldValue,
@@ -512,6 +516,12 @@ export const nodesSlice = createSlice({
fieldFluxReduxModelValueChanged: (state, action: FieldValueAction<FluxReduxModelFieldValue>) => {
fieldValueReducer(state, action, zFluxReduxModelFieldValue);
},
fieldImagen3ModelValueChanged: (state, action: FieldValueAction<Imagen3ModelFieldValue>) => {
fieldValueReducer(state, action, zImagen3ModelFieldValue);
},
fieldChatGPT4oModelValueChanged: (state, action: FieldValueAction<ChatGPT4oModelFieldValue>) => {
fieldValueReducer(state, action, zChatGPT4oModelFieldValue);
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue);
},
@@ -679,6 +689,8 @@ export const {
fieldFluxVAEModelValueChanged,
fieldSigLipModelValueChanged,
fieldFluxReduxModelValueChanged,
fieldImagen3ModelValueChanged,
fieldChatGPT4oModelValueChanged,
fieldFloatGeneratorValueChanged,
fieldIntegerGeneratorValueChanged,
fieldStringGeneratorValueChanged,

View File

@@ -1,4 +1,5 @@
import type {
BaseModelType,
BoardField,
Classification,
ColorField,
@@ -9,9 +10,10 @@ import type {
ModelIdentifierField,
ProgressImage,
SchedulerField,
SubModelType,
T2IAdapterField,
} from 'features/nodes/types/common';
import type { Invocation, S } from 'services/api/types';
import type { Invocation, ModelType, S } from 'services/api/types';
import type { Equals, Extends } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
@@ -34,6 +36,9 @@ describe('Common types', () => {
// Model component types
test('ModelIdentifier', () => assert<Equals<ModelIdentifierField, S['ModelIdentifierField']>>());
test('ModelIdentifier', () => assert<Equals<BaseModelType, S['BaseModelType']>>());
test('ModelIdentifier', () => assert<Equals<SubModelType, S['SubModelType']>>());
test('ModelIdentifier', () => assert<Equals<ModelType, S['ModelType']>>());
// Misc types
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());

View File

@@ -8,6 +8,11 @@ export const zImageField = z.object({
image_name: z.string().trim().min(1),
});
export type ImageField = z.infer<typeof zImageField>;
export const isImageField = (field: unknown): field is ImageField => zImageField.safeParse(field).success;
const zImageFieldCollection = z.array(zImageField);
type ImageFieldCollection = z.infer<typeof zImageFieldCollection>;
export const isImageFieldCollection = (field: unknown): field is ImageFieldCollection =>
zImageFieldCollection.safeParse(field).success;
export const zBoardField = z.object({
board_id: z.string().trim().min(1),
@@ -61,8 +66,20 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner', 'flux', 'cogview4']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4']);
const zBaseModel = z.enum([
'any',
'sd-1',
'sd-2',
'sd-3',
'sdxl',
'sdxl-refiner',
'flux',
'cogview4',
'imagen3',
'chatgpt-4o',
]);
export type BaseModelType = z.infer<typeof zBaseModel>;
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o']);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([
@@ -98,6 +115,7 @@ const zSubModelType = z.enum([
'scheduler',
'safety_checker',
]);
export type SubModelType = z.infer<typeof zSubModelType>;
export const zModelIdentifierField = z.object({
key: z.string().min(1),
hash: z.string().min(1),

View File

@@ -248,6 +248,14 @@ const zFluxReduxModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxReduxModelField'),
originalType: zStatelessFieldType.optional(),
});
const zImagen3ModelFieldType = zFieldTypeBase.extend({
name: z.literal('Imagen3ModelField'),
originalType: zStatelessFieldType.optional(),
});
const zChatGPT4oModelFieldType = zFieldTypeBase.extend({
name: z.literal('ChatGPT4oModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
@@ -298,6 +306,8 @@ const zStatefulFieldType = z.union([
zFluxVAEModelFieldType,
zSigLipModelFieldType,
zFluxReduxModelFieldType,
zImagen3ModelFieldType,
zChatGPT4oModelFieldType,
zColorFieldType,
zSchedulerFieldType,
zFloatGeneratorFieldType,
@@ -336,6 +346,8 @@ const modelFieldTypeNames = [
zFluxVAEModelFieldType.shape.name.value,
zSigLipModelFieldType.shape.name.value,
zFluxReduxModelFieldType.shape.name.value,
zImagen3ModelFieldType.shape.name.value,
zChatGPT4oModelFieldType.shape.name.value,
// Stateless model fields
'UNetField',
'VAEField',
@@ -1177,6 +1189,42 @@ export const isFluxReduxModelFieldInputTemplate =
buildTemplateTypeGuard<FluxReduxModelFieldInputTemplate>('FluxReduxModelField');
// #endregion
// #region Imagen3ModelField
export const zImagen3ModelFieldValue = zModelIdentifierField.optional();
const zImagen3ModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zImagen3ModelFieldValue,
});
const zImagen3ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zImagen3ModelFieldType,
originalType: zFieldType.optional(),
default: zImagen3ModelFieldValue,
});
export type Imagen3ModelFieldValue = z.infer<typeof zImagen3ModelFieldValue>;
export type Imagen3ModelFieldInputInstance = z.infer<typeof zImagen3ModelFieldInputInstance>;
export type Imagen3ModelFieldInputTemplate = z.infer<typeof zImagen3ModelFieldInputTemplate>;
export const isImagen3ModelFieldInputInstance = buildInstanceTypeGuard(zImagen3ModelFieldInputInstance);
export const isImagen3ModelFieldInputTemplate =
buildTemplateTypeGuard<Imagen3ModelFieldInputTemplate>('Imagen3ModelField');
// #endregion
// #region ChatGPT4oModelField
export const zChatGPT4oModelFieldValue = zModelIdentifierField.optional();
const zChatGPT4oModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zChatGPT4oModelFieldValue,
});
const zChatGPT4oModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zChatGPT4oModelFieldType,
originalType: zFieldType.optional(),
default: zChatGPT4oModelFieldValue,
});
export type ChatGPT4oModelFieldValue = z.infer<typeof zChatGPT4oModelFieldValue>;
export type ChatGPT4oModelFieldInputInstance = z.infer<typeof zChatGPT4oModelFieldInputInstance>;
export type ChatGPT4oModelFieldInputTemplate = z.infer<typeof zChatGPT4oModelFieldInputTemplate>;
export const isChatGPT4oModelFieldInputInstance = buildInstanceTypeGuard(zChatGPT4oModelFieldInputInstance);
export const isChatGPT4oModelFieldInputTemplate =
buildTemplateTypeGuard<ChatGPT4oModelFieldInputTemplate>('ChatGPT4oModelField');
// #endregion
// #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional();
const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -1808,6 +1856,8 @@ export const zStatefulFieldValue = z.union([
zControlLoRAModelFieldValue,
zSigLipModelFieldValue,
zFluxReduxModelFieldValue,
zImagen3ModelFieldValue,
zChatGPT4oModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
zFloatGeneratorFieldValue,
@@ -1898,6 +1948,8 @@ const zStatefulFieldInputTemplate = z.union([
zControlLoRAModelFieldInputTemplate,
zSigLipModelFieldInputTemplate,
zFluxReduxModelFieldInputTemplate,
zImagen3ModelFieldInputTemplate,
zChatGPT4oModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,

View File

@@ -1,38 +1,63 @@
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from 'app/constants';
import type { RootState } from 'app/store/store';
import { generateSeeds } from 'common/util/generateSeeds';
import randomInt from 'common/util/randomInt';
import type { SeedBehaviour } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { range } from 'lodash-es';
import type { components } from 'services/api/schema';
import type { Batch, EnqueueBatchArg, Invocation } from 'services/api/types';
import type { Batch, EnqueueBatchArg } from 'services/api/types';
import { assert } from 'tsafe';
import type { ConditioningNodes, NoiseNodes } from './types';
const getExtendedPrompts = (arg: {
seedBehaviour: SeedBehaviour;
iterations: number;
prompts: string[];
model: ModelIdentifierField;
}): string[] => {
const { seedBehaviour, iterations, prompts, model } = arg;
// Normally, the seed behaviour implicity determines the batch size. But when we use models without seeds (like
// ChatGPT 4o) in conjunction with the per-prompt seed behaviour, we lose out on that implicit batch size. To rectify
// this, we need to create a batch of the right size by repeating the prompts.
if (seedBehaviour === 'PER_PROMPT' || model.base === 'chatgpt-4o') {
return range(iterations).flatMap(() => prompts);
}
return prompts;
};
export const prepareLinearUIBatch = (
state: RootState,
g: Graph,
prepend: boolean,
noise: Invocation<NoiseNodes>,
posCond: Invocation<ConditioningNodes>,
origin: 'canvas' | 'workflows' | 'upscaling',
destination: 'canvas' | 'gallery'
): EnqueueBatchArg => {
export const prepareLinearUIBatch = (arg: {
state: RootState;
g: Graph;
prepend: boolean;
seedFieldIdentifier?: FieldIdentifier;
positivePromptFieldIdentifier: FieldIdentifier;
origin: 'canvas' | 'workflows' | 'upscaling';
destination: 'canvas' | 'gallery';
}): EnqueueBatchArg => {
const { state, g, prepend, seedFieldIdentifier, positivePromptFieldIdentifier, origin, destination } = arg;
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.params;
const { prompts, seedBehaviour } = state.dynamicPrompts;
assert(model, 'No model found in state when preparing batch');
const data: Batch['data'] = [];
const firstBatchDatumList: components['schemas']['BatchDatum'][] = [];
const secondBatchDatumList: components['schemas']['BatchDatum'][] = [];
// add seeds first to ensure the output order groups the prompts
if (seedBehaviour === 'PER_PROMPT') {
if (seedFieldIdentifier && seedBehaviour === 'PER_PROMPT') {
const seeds = generateSeeds({
count: prompts.length * iterations,
start: shouldRandomizeSeed ? undefined : seed,
// Imagen3's support for seeded generation is iffy, we are just not going too use it in linear UI generations.
start:
model.base === 'imagen3' ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) : shouldRandomizeSeed ? undefined : seed,
});
firstBatchDatumList.push({
node_path: noise.id,
field_name: 'seed',
node_path: seedFieldIdentifier.nodeId,
field_name: seedFieldIdentifier.fieldName,
items: seeds,
});
@@ -43,16 +68,18 @@ export const prepareLinearUIBatch = (
field_name: 'seed',
items: seeds,
});
} else {
} else if (seedFieldIdentifier && seedBehaviour === 'PER_ITERATION') {
// seedBehaviour = SeedBehaviour.PerRun
const seeds = generateSeeds({
count: iterations,
start: shouldRandomizeSeed ? undefined : seed,
// Imagen3's support for seeded generation is iffy, we are just not going too use in in linear UI generations.
start:
model.base === 'imagen3' ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) : shouldRandomizeSeed ? undefined : seed,
});
secondBatchDatumList.push({
node_path: noise.id,
field_name: 'seed',
node_path: seedFieldIdentifier.nodeId,
field_name: seedFieldIdentifier.fieldName,
items: seeds,
});
@@ -66,12 +93,12 @@ export const prepareLinearUIBatch = (
data.push(secondBatchDatumList);
}
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
const extendedPrompts = getExtendedPrompts({ seedBehaviour, iterations, prompts, model });
// zipped batch of prompts
firstBatchDatumList.push({
node_path: posCond.id,
field_name: 'prompt',
node_path: positivePromptFieldIdentifier.nodeId,
field_name: positivePromptFieldIdentifier.fieldName,
items: extendedPrompts,
});
@@ -83,9 +110,9 @@ export const prepareLinearUIBatch = (
items: extendedPrompts,
});
if (shouldConcatPrompts && model?.base === 'sdxl') {
if (shouldConcatPrompts && model.base === 'sdxl') {
firstBatchDatumList.push({
node_path: posCond.id,
node_path: positivePromptFieldIdentifier.nodeId,
field_name: 'style',
items: extendedPrompts,
});

View File

@@ -1,9 +1,9 @@
import type { RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
@@ -12,7 +12,7 @@ import { getBoardField, selectPresetModifiedPrompts } from './graphBuilderUtils'
export const buildMultidiffusionUpscaleGraph = async (
state: RootState
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> }> => {
): Promise<{ g: Graph; seedFieldIdentifier: FieldIdentifier; positivePromptFieldIdentifier: FieldIdentifier }> => {
const {
model,
upscaleCfgScale: cfg_scale,
@@ -243,5 +243,9 @@ export const buildMultidiffusionUpscaleGraph = async (
g.addEdge(controlNetCollector, 'collection', tiledMultidiffusion, 'control');
return { g, noise, posCond };
return {
g,
seedFieldIdentifier: { nodeId: noise.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
};
};

View File

@@ -0,0 +1,126 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isChatGPT4oAspectRatioID, isChatGPT4oReferenceImageConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { type ImageField, zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
CANVAS_OUTPUT_PREFIX,
getBoardField,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const buildChatGPT4oGraph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
if (generationMode !== 'txt2img' && generationMode !== 'img2img') {
throw new UnsupportedGenerationModeError(t('toast.chatGPT4oIncompatibleGenerationMode'));
}
log.debug({ generationMode }, 'Building GPT Image graph');
const model = selectMainModelConfig(state);
const canvas = selectCanvasSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);
const { bbox } = canvas;
const { positivePrompt } = selectPresetModifiedPrompts(state);
assert(model, 'No model found in state');
assert(model.base === 'chatgpt-4o', 'Model is not a ChatGPT 4o model');
assert(isChatGPT4oAspectRatioID(bbox.aspectRatio.id), 'ChatGPT 4o does not support this aspect ratio');
const validRefImages = canvas.referenceImages.entities
.filter((entity) => entity.isEnabled)
.filter((entity) => isChatGPT4oReferenceImageConfig(entity.ipAdapter))
.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0)
.toReversed(); // sends them in order they are displayed in the list
let reference_images: ImageField[] | undefined = undefined;
if (validRefImages.length > 0) {
reference_images = [];
for (const entity of validRefImages) {
assert(entity.ipAdapter.image, 'Image is required for reference image');
reference_images.push({
image_name: entity.ipAdapter.image.image_name,
});
}
}
const is_intermediate = canvasSettings.sendToCanvas;
const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);
if (generationMode === 'txt2img') {
const g = new Graph(getPrefixedId('chatgpt_4o_txt2img_graph'));
const gptImage = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'chatgpt_4o_generate_image',
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
aspect_ratio: bbox.aspectRatio.id,
reference_images,
use_cache: false,
is_intermediate,
board,
});
g.upsertMetadata({
positive_prompt: positivePrompt,
model: Graph.getModelMetadataField(model),
width: bbox.rect.width,
height: bbox.rect.height,
});
return {
g,
positivePromptFieldIdentifier: { nodeId: gptImage.id, fieldName: 'positive_prompt' },
};
}
if (generationMode === 'img2img') {
const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
is_intermediate: true,
silent: true,
});
const g = new Graph(getPrefixedId('chatgpt_4o_img2img_graph'));
const gptImage = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'chatgpt_4o_edit_image',
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
aspect_ratio: bbox.aspectRatio.id,
base_image: { image_name },
reference_images,
use_cache: false,
is_intermediate,
board,
});
g.upsertMetadata({
positive_prompt: positivePrompt,
model: Graph.getModelMetadataField(model),
width: bbox.rect.width,
height: bbox.rect.height,
});
return {
g,
positivePromptFieldIdentifier: { nodeId: gptImage.id, fieldName: 'positive_prompt' },
};
}
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for ChatGPT ');
};

View File

@@ -19,7 +19,7 @@ import {
getSizes,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -27,10 +27,7 @@ import { assert } from 'tsafe';
const log = logger('system');
export const buildCogView4Graph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'cogview4_denoise'>; posCond: Invocation<'cogview4_text_encoder'> }> => {
export const buildCogView4Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building CogView4 graph');
@@ -186,5 +183,9 @@ export const buildCogView4Graph = async (
});
g.setMetadataReceivingNode(canvasOutput);
return { g, noise: denoise, posCond };
return {
g,
seedFieldIdentifier: { nodeId: denoise.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
};
};

View File

@@ -22,7 +22,11 @@ import {
getSizes,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import {
type GraphBuilderReturn,
type ImageOutputNodes,
UnsupportedGenerationModeError,
} from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Invocation } from 'services/api/types';
@@ -34,10 +38,7 @@ import { addIPAdapters } from './addIPAdapters';
const log = logger('system');
export const buildFLUXGraph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise' | 'flux_denoise'>; posCond: Invocation<'flux_text_encoder'> }> => {
export const buildFLUXGraph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building FLUX graph');
@@ -83,7 +84,9 @@ export const buildFLUXGraph = async (
//
// The other asserts above are just for sanity & type check and should never be hit, so they do not have
// translations.
assert(generationMode === 'inpaint' || generationMode === 'outpaint', t('toast.fluxFillIncompatibleWithT2IAndI2I'));
if (generationMode === 'txt2img' || generationMode === 'img2img') {
throw new UnsupportedGenerationModeError(t('toast.fluxFillIncompatibleWithT2IAndI2I'));
}
// FLUX Fill wants much higher guidance values than normal FLUX - silently "fix" the value for the user.
// TODO(psyche): Figure out a way to alert the user that this is happening - maybe return warnings from the graph
@@ -336,5 +339,9 @@ export const buildFLUXGraph = async (
});
g.setMetadataReceivingNode(canvasOutput);
return { g, noise: denoise, posCond };
return {
g,
seedFieldIdentifier: { nodeId: denoise.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
};
};

View File

@@ -0,0 +1,78 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isImagen3AspectRatioID } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
CANVAS_OUTPUT_PREFIX,
getBoardField,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const buildImagen3Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagen3IncompatibleGenerationMode'));
}
log.debug({ generationMode }, 'Building Imagen3 graph');
const canvas = selectCanvasSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);
const { bbox } = canvas;
const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
const model = selectMainModelConfig(state);
assert(model, 'No model found for Imagen3 graph');
assert(model.base === 'imagen3', 'Imagen3 graph requires Imagen3 model');
assert(isImagen3AspectRatioID(bbox.aspectRatio.id), 'Imagen3 does not support this aspect ratio');
assert(positivePrompt.length > 0, 'Imagen3 requires positive prompt to have at least one character');
const is_intermediate = canvasSettings.sendToCanvas;
const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);
if (generationMode === 'txt2img') {
const g = new Graph(getPrefixedId('imagen3_txt2img_graph'));
const imagen3 = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'google_imagen3_generate_image',
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
aspect_ratio: bbox.aspectRatio.id,
enhance_prompt: true,
// When enhance_prompt is true, Imagen3 will return a new image every time, ignoring the seed.
use_cache: false,
is_intermediate,
board,
});
g.upsertMetadata({
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
width: bbox.rect.width,
height: bbox.rect.height,
model: Graph.getModelMetadataField(model),
});
return {
g,
seedFieldIdentifier: { nodeId: imagen3.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: imagen3.id, fieldName: 'positive_prompt' },
};
}
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Imagen3');
};

View File

@@ -23,7 +23,7 @@ import {
getSizes,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -33,10 +33,7 @@ import { addRegions } from './addRegions';
const log = logger('system');
export const buildSD1Graph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel'> }> => {
export const buildSD1Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SD1/SD2 graph');
@@ -316,5 +313,9 @@ export const buildSD1Graph = async (
});
g.setMetadataReceivingNode(canvasOutput);
return { g, noise, posCond };
return {
g,
seedFieldIdentifier: { nodeId: noise.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
};
};

View File

@@ -18,7 +18,7 @@ import {
getSizes,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -26,10 +26,7 @@ import { assert } from 'tsafe';
const log = logger('system');
export const buildSD3Graph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'sd3_denoise'>; posCond: Invocation<'sd3_text_encoder'> }> => {
export const buildSD3Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SD3 graph');
@@ -211,5 +208,9 @@ export const buildSD3Graph = async (
});
g.setMetadataReceivingNode(canvasOutput);
return { g, noise: denoise, posCond };
return {
g,
seedFieldIdentifier: { nodeId: denoise.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
};
};

View File

@@ -23,7 +23,7 @@ import {
getSizes,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -33,10 +33,7 @@ import { addRegions } from './addRegions';
const log = logger('system');
export const buildSDXLGraph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'sdxl_compel_prompt'> }> => {
export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SDXL graph');
@@ -323,5 +320,9 @@ export const buildSDXLGraph = async (
});
g.setMetadataReceivingNode(canvasOutput);
return { g, noise, posCond };
return {
g,
seedFieldIdentifier: { nodeId: noise.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
};
};

View File

@@ -1,3 +1,6 @@
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
export type ImageOutputNodes =
| 'l2i'
| 'img_nsfw'
@@ -23,11 +26,16 @@ export type MainModelLoaderNodes =
| 'cogview4_model_loader';
export type VaeSourceNodes = 'seamless' | 'vae_loader';
export type NoiseNodes = 'noise' | 'flux_denoise' | 'sd3_denoise' | 'cogview4_denoise';
export type ConditioningNodes =
| 'compel'
| 'sdxl_compel_prompt'
| 'flux_text_encoder'
| 'sd3_text_encoder'
| 'cogview4_text_encoder';
export type GraphBuilderReturn = {
g: Graph;
seedFieldIdentifier?: FieldIdentifier;
positivePromptFieldIdentifier: FieldIdentifier;
};
export class UnsupportedGenerationModeError extends Error {
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}

View File

@@ -33,6 +33,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
ControlLoRAModelField: undefined,
SigLipModelField: undefined,
FluxReduxModelField: undefined,
Imagen3ModelField: undefined,
ChatGPT4oModelField: undefined,
FloatGeneratorField: undefined,
IntegerGeneratorField: undefined,
StringGeneratorField: undefined,

Some files were not shown because too many files have changed in this diff Show More