Compare commits

...

73 Commits
v6.7.0 ... main

Author SHA1 Message Date
psychedelicious
3707c3b034 fix(ui): do not bake opacity when rasterizing layer adjustments 2025-09-22 11:43:08 +10:00
Mary Hipp
5885db4ab5 ruff 2025-09-19 11:07:36 -04:00
Mary Hipp
36ed9b750d restore list_queue_items method 2025-09-19 11:07:36 -04:00
psychedelicious
3cec06f86e chore(ui): typegen 2025-09-19 22:13:12 +10:00
psychedelicious
28b5f7a1c5 feat(nodes): better deprecation handling for ui_type
- Move migration of model-specific ui_types into BaseInvocation. This
gives us access to the node and field names, so the warnings are more
useful to the end user.
- Ensure we serialize the fields' json_schema_extra with enum values.
This wasn't a problem until now, when it interferes with migrating
ui_type cleanly. It's a transparent change.
- Improve warnings when validating fields (which includes the ui_type
migration logic)
2025-09-19 22:13:12 +10:00
psychedelicious
22cbb23ae0 fix(ui): ref images for flux kontext & api models not parsed correctly 2025-09-19 21:40:17 +10:00
Riccardo Giovanetti
4d585e3eec translationBot(ui): update translation (Italian)
Currently translated at 98.4% (2130 of 2163 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2127 of 2161 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-09-18 14:01:31 +10:00
psychedelicious
006b4356bb chore(ui): typegen 2025-09-18 12:39:27 +10:00
psychedelicious
da947866f2 fix(nodes): ensure SD2 models are pickable in loader/cnet nodes 2025-09-18 12:39:27 +10:00
psychedelicious
84a2cc6fc9 chore(ui): typegen 2025-09-18 12:39:27 +10:00
psychedelicious
b50534bb49 revert(nodes): do not deprecate ui_type for output fields! only deprecate the model ui types 2025-09-18 12:39:27 +10:00
psychedelicious
c305e79fee tests(ui): update tests to reflect new model parsing logic 2025-09-18 12:39:27 +10:00
psychedelicious
c32949d113 tidy(nodes): mark all UIType.*ModelField as deprecated 2025-09-18 12:39:27 +10:00
psychedelicious
87a98902da tidy(nodes): remove unused UIType.Video 2025-09-18 12:39:27 +10:00
psychedelicious
2857a446c9 docs(nodes): update docstrings for InputField 2025-09-18 12:39:27 +10:00
psychedelicious
035d9432bd feat(ui): support filtering on model format 2025-09-18 12:39:27 +10:00
psychedelicious
bdeb9fb1cf chore(ui): typegen 2025-09-18 12:39:27 +10:00
psychedelicious
dadff57061 feat(nodes): add ui_model_format filter for nodes 2025-09-18 12:39:27 +10:00
psychedelicious
480857ae4e fix(nodes): add base to SD1 model loader 2025-09-18 12:39:27 +10:00
psychedelicious
eaf0624004 feat(ui): remove explicit model type handling from workflow editor 2025-09-18 12:39:27 +10:00
psychedelicious
58bca1b9f4 feat(nodes): use new ui_model_[base|type|variant] on all core nodes 2025-09-18 12:39:27 +10:00
psychedelicious
54aa6908fa feat(ui): update invocation parsing to handle new ui_model_[base|type|variant] attrs 2025-09-18 12:39:27 +10:00
psychedelicious
e6d9daca96 chore(ui): typegen 2025-09-18 12:39:27 +10:00
psychedelicious
6e5a529cb7 feat(nodes): add ui_model_[base|type|variant] to InputField args for dynamic UI generation 2025-09-18 12:39:27 +10:00
Iq1pl
8c742a6e38 ruff format 2025-09-18 11:05:32 +10:00
Iq1pl
693373f1c1 Update ip_adapter.py
added support for NOOB-IPA-MARK1
2025-09-18 11:05:32 +10:00
Josh Corbett
4809080fd9 fix(ui): allow scrolling in ModelPane 2025-09-18 10:33:22 +10:00
psychedelicious
efcb1bea7f chore: bump version to v6.8.0rc1 2025-09-17 13:57:43 +10:00
psychedelicious
e0d7a401f3 feat(ui): make ref images croppable 2025-09-17 13:43:13 +10:00
psychedelicious
aac979e9a4 fix(ui): issue w/ setting initial aspect ratio in cropper 2025-09-17 13:43:13 +10:00
psychedelicious
3b0d7f076d tidy(ui): rename from "editor" to "cropper", minor cleanup 2025-09-17 13:43:13 +10:00
psychedelicious
e1acbcdbd5 fix(ui): store floats for box 2025-09-17 13:43:13 +10:00
psychedelicious
7d9b81550b feat(ui): revert to original image when crop discarded 2025-09-17 13:43:13 +10:00
psychedelicious
6a447dd1fe refactor(ui): remove "apply", "start" and "cancel" concepts from editor 2025-09-17 13:43:13 +10:00
psychedelicious
c2dc63ddbc fix(ui): video graphs 2025-09-17 13:43:13 +10:00
psychedelicious
1bc689d531 docs(ui): add comments to startingframeimage 2025-09-17 13:43:13 +10:00
psychedelicious
4829975827 feat(ui): make the editor components not care about the image 2025-09-17 13:43:13 +10:00
psychedelicious
49da4e00c3 feat(ui): add concept for editable image state 2025-09-17 13:43:13 +10:00
psychedelicious
89dfe5e729 docs(ui): add comments to editor 2025-09-17 13:43:13 +10:00
psychedelicious
6816d366df tidy(ui): editor misc 2025-09-17 13:43:13 +10:00
psychedelicious
9d3d2a36c9 tidy(ui): editor listeners 2025-09-17 13:43:13 +10:00
psychedelicious
ed231044c8 refactor(ui): simplify crop constraints 2025-09-17 13:43:13 +10:00
psychedelicious
b51a232794 feat(ui): extract config to own obj 2025-09-17 13:43:13 +10:00
psychedelicious
4412143a6e feat(ui): clean up editor 2025-09-17 13:43:13 +10:00
psychedelicious
de11cafdb3 refactor(ui): editor (wip) 2025-09-17 13:43:13 +10:00
psychedelicious
4d9114aa7d refactor(ui): editor (wip) 2025-09-17 13:43:13 +10:00
psychedelicious
67e2da1ebf refactor(ui): editor (wip) 2025-09-17 13:43:13 +10:00
psychedelicious
33ecc591c3 refactor(ui): editor init 2025-09-17 13:43:13 +10:00
psychedelicious
b57459a226 chore(ui): lint 2025-09-17 13:43:13 +10:00
psychedelicious
01282b1c90 feat(ui): do not clear crop when canceling 2025-09-17 13:43:13 +10:00
psychedelicious
3f302906dc feat(ui): crop doesn't hide outside cropped region 2025-09-17 13:43:13 +10:00
psychedelicious
81d56596fb tidy(ui): cleanup 2025-09-17 13:43:13 +10:00
psychedelicious
b536b0df0c feat(ui): misc iterate on editor 2025-09-17 13:43:13 +10:00
psychedelicious
692af1d93d feat(ui): type narrowing for editor output types 2025-09-17 13:43:13 +10:00
psychedelicious
bb7ef77b50 tidy(ui): lint/react conventions for editor component 2025-09-17 13:43:13 +10:00
psychedelicious
1862548573 feat(ui): image editor bg checkerboard pattern 2025-09-17 13:43:13 +10:00
psychedelicious
242c1b6350 feat(ui): tweak editor konva styles 2025-09-17 13:43:13 +10:00
psychedelicious
fc6e4bb04e tidy(ui): editor component cleanup 2025-09-17 13:43:13 +10:00
psychedelicious
20841abca6 tidy(ui): editor cleanup 2025-09-17 13:43:13 +10:00
psychedelicious
e8b69d99a4 chore(ui): lint 2025-09-17 13:43:13 +10:00
Mary Hipp
d6eaff8237 create editImageModal that takes an imageDTO, loads blob onto canvas, and allows cropping. cropped blob is uploaded as new asset 2025-09-17 13:43:13 +10:00
Mary Hipp
068b095956 show warning state with tooltip if starting frame image aspect ratio does not match the video output aspect ratio' 2025-09-17 13:43:13 +10:00
psychedelicious
f795a47340 tidy(ui): remove unused translation string 2025-09-16 15:04:03 +10:00
psychedelicious
df47345eb0 feat(ui): add translation strings for prompt history 2025-09-16 15:04:03 +10:00
psychedelicious
def04095a4 feat(ui): tweak prompt history styling 2025-09-16 15:04:03 +10:00
psychedelicious
28be8f0911 refactor(ui): simplify prompt history shortcuts 2025-09-16 15:04:03 +10:00
Kent Keirsey
b50c44bac0 handle potential for invalid list item 2025-09-16 15:04:03 +10:00
Kent Keirsey
b4ce0e02fc lint 2025-09-16 15:04:03 +10:00
Kent Keirsey
d6442d9a34 Prompt history shortcuts 2025-09-16 15:04:03 +10:00
Josh Corbett
4528bcafaf feat(model manager): add ModelFooter component and reusable ModelDeleteButton 2025-09-16 12:29:57 +10:00
Josh Corbett
8b82b81ee2 fix(ModelImage): change MODEL_IMAGE_THUMBNAIL_SIZE to a local constant 2025-09-16 12:29:57 +10:00
Josh Corbett
757acdd49e feat(model manager): 💄 update model manager ui, initial commit 2025-09-16 12:29:57 +10:00
psychedelicious
94b7cc583a fix(ui): do not reset params state on studio init nav to generate tab 2025-09-16 12:25:25 +10:00
132 changed files with 3686 additions and 3447 deletions

View File

@@ -36,6 +36,9 @@ from pydantic_core import PydanticUndefined
from invokeai.app.invocations.fields import (
FieldKind,
Input,
InputFieldJSONSchemaExtra,
UIType,
migrate_model_ui_type,
)
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -256,7 +259,9 @@ class BaseInvocation(ABC, BaseModel):
is_intermediate: bool = Field(
default=False,
description="Whether or not this is an intermediate invocation.",
json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute},
json_schema_extra=InputFieldJSONSchemaExtra(
input=Input.Direct, field_kind=FieldKind.NodeAttribute, ui_type=UIType._IsIntermediate
).model_dump(exclude_none=True),
)
use_cache: bool = Field(
default=True,
@@ -445,6 +450,15 @@ with warnings.catch_warnings():
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
def is_enum_member(value: Any, enum_class: type[Enum]) -> bool:
"""Checks if a value is a member of an enum class."""
try:
enum_class(value)
return True
except ValueError:
return False
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
"""
Validates the fields of an invocation or invocation output:
@@ -456,51 +470,99 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
"""
for name, field in model_fields.items():
if name in RESERVED_PYDANTIC_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved by pydantic)")
if not field.annotation:
raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')
raise InvalidFieldError(f"{model_type}.{name}: Invalid field type (missing annotation)")
if not isinstance(field.json_schema_extra, dict):
raise InvalidFieldError(
f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)'
)
raise InvalidFieldError(f"{model_type}.{name}: Invalid field definition (missing json_schema_extra dict)")
field_kind = field.json_schema_extra.get("field_kind", None)
# must have a field_kind
if not isinstance(field_kind, FieldKind):
if not is_enum_member(field_kind, FieldKind):
raise InvalidFieldError(
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
f"{model_type}.{name}: Invalid field definition for (maybe it's not an InputField or OutputField?)"
)
if field_kind is FieldKind.Input and (
if field_kind == FieldKind.Input.value and (
name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES
):
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved input field name)")
if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
if field_kind == FieldKind.Output.value and name in RESERVED_OUTPUT_FIELD_NAMES:
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved output field name)")
if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES:
raise InvalidFieldError(
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
)
if field_kind == FieldKind.Internal.value and name not in RESERVED_INPUT_FIELD_NAMES:
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (internal field without reserved name)")
# node attribute fields *must* be in the reserved list
if (
field_kind is FieldKind.NodeAttribute
field_kind == FieldKind.NodeAttribute.value
and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES
and name not in RESERVED_OUTPUT_FIELD_NAMES
):
raise InvalidFieldError(
f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)'
f"{model_type}.{name}: Invalid field name (node attribute field without reserved name)"
)
ui_type = field.json_schema_extra.get("ui_type", None)
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
logger.warning(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
field.json_schema_extra.pop("ui_type")
ui_model_base = field.json_schema_extra.get("ui_model_base", None)
ui_model_type = field.json_schema_extra.get("ui_model_type", None)
ui_model_variant = field.json_schema_extra.get("ui_model_variant", None)
ui_model_format = field.json_schema_extra.get("ui_model_format", None)
if ui_type is not None:
# There are 3 cases where we may need to take action:
#
# 1. The ui_type is a migratable, deprecated value. For example, ui_type=UIType.MainModel value is
# deprecated and should be migrated to:
# - ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]
# - ui_model_type=[ModelType.Main]
#
# 2. ui_type was set in conjunction with any of the new ui_model_[base|type|variant|format] fields, which
# is not allowed (they are mutually exclusive). In this case, we ignore ui_type and log a warning.
#
# 3. ui_type is a deprecated value that is not migratable. For example, ui_type=UIType.Image is deprecated;
# Image fields are now automatically detected based on the field's type annotation. In this case, we
# ignore ui_type and log a warning.
#
# The cases must be checked in this order to ensure proper handling.
# Easier to work with as an enum
ui_type = UIType(ui_type)
# The enum member values are not always the same as their names - we want to log the name so the user can
# easily review their code and see where the deprecated enum member is used.
human_readable_name = f"UIType.{ui_type.name}"
# Case 1: migratable deprecated value
did_migrate = migrate_model_ui_type(ui_type, field.json_schema_extra)
if did_migrate:
logger.warning(
f'{model_type}.{name}: Migrated deprecated "ui_type" "{human_readable_name}" to new ui_model_[base|type|variant|format] fields'
)
field.json_schema_extra.pop("ui_type")
# Case 2: mutually exclusive with new fields
elif (
ui_model_base is not None
or ui_model_type is not None
or ui_model_variant is not None
or ui_model_format is not None
):
logger.warning(
f'{model_type}.{name}: "ui_type" is mutually exclusive with "ui_model_[base|type|format|variant]", ignoring "ui_type"'
)
field.json_schema_extra.pop("ui_type")
# Case 3: deprecated value that is not migratable
elif ui_type.startswith("DEPRECATED_"):
logger.warning(f'{model_type}.{name}: Deprecated "ui_type" "{human_readable_name}", ignoring')
field.json_schema_extra.pop("ui_type")
return None

View File

@@ -5,7 +5,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import (
GlmEncoderField,
ModelIdentifierField,
@@ -14,6 +14,7 @@ from invokeai.app.invocations.model import (
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@invocation_output("cogview4_model_loader_output")
@@ -38,8 +39,9 @@ class CogView4ModelLoaderInvocation(BaseInvocation):
model: ModelIdentifierField = InputField(
description=FieldDescriptions.cogview4_model,
ui_type=UIType.CogView4MainModel,
input=Input.Direct,
ui_model_base=BaseModelType.CogView4,
ui_model_type=ModelType.Main,
)
def invoke(self, context: InvocationContext) -> CogView4ModelLoaderOutput:

View File

@@ -16,7 +16,6 @@ from invokeai.app.invocations.fields import (
ImageField,
InputField,
OutputField,
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
@@ -28,6 +27,7 @@ from invokeai.app.util.controlnet_utils import (
heuristic_resize_fast,
)
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
class ControlField(BaseModel):
@@ -63,13 +63,17 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
@invocation(
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3"
)
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
description=FieldDescriptions.controlnet_model,
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, BaseModelType.StableDiffusionXL],
ui_model_type=ModelType.ControlNet,
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"

View File

@@ -7,6 +7,13 @@ from pydantic_core import PydanticUndefined
from invokeai.app.util.metaenum import MetaEnum
from invokeai.backend.image_util.segment_anything.shared import BoundingBox
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
ModelFormat,
ModelType,
ModelVariantType,
)
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
@@ -39,47 +46,15 @@ class UIType(str, Enum, metaclass=MetaEnum):
used, and the type will be ignored. They are included here for backwards compatibility.
"""
# region Model Field Types
MainModel = "MainModelField"
CogView4MainModel = "CogView4MainModelField"
FluxMainModel = "FluxMainModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField"
FluxVAEModel = "FluxVAEModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
CLIPEmbedModel = "CLIPEmbedModelField"
CLIPLEmbedModel = "CLIPLEmbedModelField"
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
ControlLoRAModel = "ControlLoRAModelField"
SigLipModel = "SigLipModelField"
FluxReduxModel = "FluxReduxModelField"
LlavaOnevisionModel = "LLaVAModelField"
Imagen3Model = "Imagen3ModelField"
Imagen4Model = "Imagen4ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
Gemini2_5Model = "Gemini2_5ModelField"
FluxKontextModel = "FluxKontextModelField"
Veo3Model = "Veo3ModelField"
RunwayModel = "RunwayModelField"
# endregion
# region Misc Field Types
Scheduler = "SchedulerField"
Any = "AnyField"
Video = "VideoField"
# endregion
# region Internal Field Types
_Collection = "CollectionField"
_CollectionItem = "CollectionItemField"
_IsIntermediate = "IsIntermediate"
# endregion
# region DEPRECATED
@@ -117,13 +92,44 @@ class UIType(str, Enum, metaclass=MetaEnum):
CollectionItem = "DEPRECATED_CollectionItem"
Enum = "DEPRECATED_Enum"
WorkflowField = "DEPRECATED_WorkflowField"
IsIntermediate = "DEPRECATED_IsIntermediate"
BoardField = "DEPRECATED_BoardField"
MetadataItem = "DEPRECATED_MetadataItem"
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
MetadataDict = "DEPRECATED_MetadataDict"
# Deprecated Model Field Types - use ui_model_[base|type|variant|format] instead
MainModel = "DEPRECATED_MainModelField"
CogView4MainModel = "DEPRECATED_CogView4MainModelField"
FluxMainModel = "DEPRECATED_FluxMainModelField"
SD3MainModel = "DEPRECATED_SD3MainModelField"
SDXLMainModel = "DEPRECATED_SDXLMainModelField"
SDXLRefinerModel = "DEPRECATED_SDXLRefinerModelField"
ONNXModel = "DEPRECATED_ONNXModelField"
VAEModel = "DEPRECATED_VAEModelField"
FluxVAEModel = "DEPRECATED_FluxVAEModelField"
LoRAModel = "DEPRECATED_LoRAModelField"
ControlNetModel = "DEPRECATED_ControlNetModelField"
IPAdapterModel = "DEPRECATED_IPAdapterModelField"
T2IAdapterModel = "DEPRECATED_T2IAdapterModelField"
T5EncoderModel = "DEPRECATED_T5EncoderModelField"
CLIPEmbedModel = "DEPRECATED_CLIPEmbedModelField"
CLIPLEmbedModel = "DEPRECATED_CLIPLEmbedModelField"
CLIPGEmbedModel = "DEPRECATED_CLIPGEmbedModelField"
SpandrelImageToImageModel = "DEPRECATED_SpandrelImageToImageModelField"
ControlLoRAModel = "DEPRECATED_ControlLoRAModelField"
SigLipModel = "DEPRECATED_SigLipModelField"
FluxReduxModel = "DEPRECATED_FluxReduxModelField"
LlavaOnevisionModel = "DEPRECATED_LLaVAModelField"
Imagen3Model = "DEPRECATED_Imagen3ModelField"
Imagen4Model = "DEPRECATED_Imagen4ModelField"
ChatGPT4oModel = "DEPRECATED_ChatGPT4oModelField"
Gemini2_5Model = "DEPRECATED_Gemini2_5ModelField"
FluxKontextModel = "DEPRECATED_FluxKontextModelField"
Veo3Model = "DEPRECATED_Veo3ModelField"
RunwayModel = "DEPRECATED_RunwayModelField"
# endregion
class UIComponent(str, Enum, metaclass=MetaEnum):
"""
@@ -409,10 +415,15 @@ class InputFieldJSONSchemaExtra(BaseModel):
ui_component: Optional[UIComponent] = None
ui_order: Optional[int] = None
ui_choice_labels: Optional[dict[str, str]] = None
ui_model_base: Optional[list[BaseModelType]] = None
ui_model_type: Optional[list[ModelType]] = None
ui_model_variant: Optional[list[ClipVariantType | ModelVariantType]] = None
ui_model_format: Optional[list[ModelFormat]] = None
model_config = ConfigDict(
validate_assignment=True,
json_schema_serialization_defaults_required=True,
use_enum_values=True,
)
@@ -465,16 +476,121 @@ class OutputFieldJSONSchemaExtra(BaseModel):
"""
field_kind: FieldKind
ui_hidden: bool
ui_type: Optional[UIType]
ui_order: Optional[int]
ui_hidden: bool = False
ui_order: Optional[int] = None
ui_type: Optional[UIType] = None
model_config = ConfigDict(
validate_assignment=True,
json_schema_serialization_defaults_required=True,
use_enum_values=True,
)
def migrate_model_ui_type(ui_type: UIType | str, json_schema_extra: dict[str, Any]) -> bool:
"""Migrate deprecated model-specifier ui_type values to new-style ui_model_[base|type|variant|format] in json_schema_extra."""
if not isinstance(ui_type, UIType):
ui_type = UIType(ui_type)
ui_model_type: list[ModelType] | None = None
ui_model_base: list[BaseModelType] | None = None
ui_model_format: list[ModelFormat] | None = None
ui_model_variant: list[ClipVariantType | ModelVariantType] | None = None
match ui_type:
case UIType.MainModel:
ui_model_base = [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]
ui_model_type = [ModelType.Main]
case UIType.CogView4MainModel:
ui_model_base = [BaseModelType.CogView4]
ui_model_type = [ModelType.Main]
case UIType.FluxMainModel:
ui_model_base = [BaseModelType.Flux]
ui_model_type = [ModelType.Main]
case UIType.SD3MainModel:
ui_model_base = [BaseModelType.StableDiffusion3]
ui_model_type = [ModelType.Main]
case UIType.SDXLMainModel:
ui_model_base = [BaseModelType.StableDiffusionXL]
ui_model_type = [ModelType.Main]
case UIType.SDXLRefinerModel:
ui_model_base = [BaseModelType.StableDiffusionXLRefiner]
ui_model_type = [ModelType.Main]
case UIType.VAEModel:
ui_model_type = [ModelType.VAE]
case UIType.FluxVAEModel:
ui_model_base = [BaseModelType.Flux]
ui_model_type = [ModelType.VAE]
case UIType.LoRAModel:
ui_model_type = [ModelType.LoRA]
case UIType.ControlNetModel:
ui_model_type = [ModelType.ControlNet]
case UIType.IPAdapterModel:
ui_model_type = [ModelType.IPAdapter]
case UIType.T2IAdapterModel:
ui_model_type = [ModelType.T2IAdapter]
case UIType.T5EncoderModel:
ui_model_type = [ModelType.T5Encoder]
case UIType.CLIPEmbedModel:
ui_model_type = [ModelType.CLIPEmbed]
case UIType.CLIPLEmbedModel:
ui_model_type = [ModelType.CLIPEmbed]
ui_model_variant = [ClipVariantType.L]
case UIType.CLIPGEmbedModel:
ui_model_type = [ModelType.CLIPEmbed]
ui_model_variant = [ClipVariantType.G]
case UIType.SpandrelImageToImageModel:
ui_model_type = [ModelType.SpandrelImageToImage]
case UIType.ControlLoRAModel:
ui_model_type = [ModelType.ControlLoRa]
case UIType.SigLipModel:
ui_model_type = [ModelType.SigLIP]
case UIType.FluxReduxModel:
ui_model_type = [ModelType.FluxRedux]
case UIType.LlavaOnevisionModel:
ui_model_type = [ModelType.LlavaOnevision]
case UIType.Imagen3Model:
ui_model_base = [BaseModelType.Imagen3]
ui_model_type = [ModelType.Main]
case UIType.Imagen4Model:
ui_model_base = [BaseModelType.Imagen4]
ui_model_type = [ModelType.Main]
case UIType.ChatGPT4oModel:
ui_model_base = [BaseModelType.ChatGPT4o]
ui_model_type = [ModelType.Main]
case UIType.Gemini2_5Model:
ui_model_base = [BaseModelType.Gemini2_5]
ui_model_type = [ModelType.Main]
case UIType.FluxKontextModel:
ui_model_base = [BaseModelType.FluxKontext]
ui_model_type = [ModelType.Main]
case UIType.Veo3Model:
ui_model_base = [BaseModelType.Veo3]
ui_model_type = [ModelType.Video]
case UIType.RunwayModel:
ui_model_base = [BaseModelType.Runway]
ui_model_type = [ModelType.Video]
case _:
pass
did_migrate = False
if ui_model_type is not None:
json_schema_extra["ui_model_type"] = [m.value for m in ui_model_type]
did_migrate = True
if ui_model_base is not None:
json_schema_extra["ui_model_base"] = [m.value for m in ui_model_base]
did_migrate = True
if ui_model_format is not None:
json_schema_extra["ui_model_format"] = [m.value for m in ui_model_format]
did_migrate = True
if ui_model_variant is not None:
json_schema_extra["ui_model_variant"] = [m.value for m in ui_model_variant]
did_migrate = True
return did_migrate
def InputField(
# copied from pydantic's Field
# TODO: Can we support default_factory?
@@ -501,35 +617,63 @@ def InputField(
ui_hidden: Optional[bool] = None,
ui_order: Optional[int] = None,
ui_choice_labels: Optional[dict[str, str]] = None,
ui_model_base: Optional[BaseModelType | list[BaseModelType]] = None,
ui_model_type: Optional[ModelType | list[ModelType]] = None,
ui_model_variant: Optional[ClipVariantType | ModelVariantType | list[ClipVariantType | ModelVariantType]] = None,
ui_model_format: Optional[ModelFormat | list[ModelFormat]] = None,
) -> Any:
"""
Creates an input field for an invocation.
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field)
that adds a few extra parameters to support graph execution and the node editor UI.
:param Input input: [Input.Any] The kind of input this field requires. \
`Input.Direct` means a value must be provided on instantiation. \
`Input.Connection` means the value must be provided by a connection. \
`Input.Any` means either will do.
If the field is a `ModelIdentifierField`, use the `ui_model_[base|type|variant|format]` args to filter the model list
in the Workflow Editor. Otherwise, use `ui_type` to provide extra type hints for the UI.
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
In some situations, the field's type is not enough to infer the correct UI type. \
For example, model selection fields should render a dropdown UI component to select a model. \
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
Don't use both `ui_type` and `ui_model_[base|type|variant|format]` - if both are provided, a warning will be
logged and `ui_type` will be ignored.
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
The UI will always render a suitable component, but sometimes you want something different than the default. \
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
For this case, you could provide `UIComponent.Textarea`.
Args:
input: The kind of input this field requires.
- `Input.Direct` means a value must be provided on instantiation.
- `Input.Connection` means the value must be provided by a connection.
- `Input.Any` means either will do.
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
ui_type: Optionally provides an extra type hint for the UI. In some situations, the field's type is not enough
to infer the correct UI type. For example, Scheduler fields are enums, but we want to render a special scheduler
dropdown in the UI. Use `UIType.Scheduler` to indicate this.
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
ui_component: Optionally specifies a specific component to use in the UI. The UI will always render a suitable
component, but sometimes you want something different than the default. For example, a `string` field will
default to a single-line input, but you may want a multi-line textarea instead. In this case, you could use
`UIComponent.Textarea`.
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
ui_hidden: Specifies whether or not this field should be hidden in the UI.
ui_order: Specifies the order in which this field should be rendered in the UI. If omitted, the field will be
rendered after all fields with an explicit order, in the order they are defined in the Invocation class.
ui_model_base: Specifies the base model architectures to filter the model list by in the Workflow Editor. For
example, `ui_model_base=BaseModelType.StableDiffusionXL` will show only SDXL architecture models. This arg is
only valid if this Input field is annotated as a `ModelIdentifierField`.
ui_model_type: Specifies the model type(s) to filter the model list by in the Workflow Editor. For example,
`ui_model_type=ModelType.VAE` will show only VAE models. This arg is only valid if this Input field is
annotated as a `ModelIdentifierField`.
ui_model_variant: Specifies the model variant(s) to filter the model list by in the Workflow Editor. For example,
`ui_model_variant=ModelVariantType.Inpainting` will show only inpainting models. This arg is only valid if this
Input field is annotated as a `ModelIdentifierField`.
ui_model_format: Specifies the model format(s) to filter the model list by in the Workflow Editor. For example,
`ui_model_format=ModelFormat.Diffusers` will show only models in the diffusers format. This arg is only valid
if this Input field is annotated as a `ModelIdentifierField`.
ui_choice_labels: Specifies the labels to use for the choices in an enum field. If omitted, the enum values
will be used. This arg is only valid if the field is annotated with as a `Literal`. For example,
`Literal["choice1", "choice2", "choice3"]` with `ui_choice_labels={"choice1": "Choice 1", "choice2": "Choice 2",
"choice3": "Choice 3"}` will render a dropdown with the labels "Choice 1", "Choice 2" and "Choice 3".
"""
json_schema_extra_ = InputFieldJSONSchemaExtra(
@@ -537,8 +681,6 @@ def InputField(
field_kind=FieldKind.Input,
)
if ui_type is not None:
json_schema_extra_.ui_type = ui_type
if ui_component is not None:
json_schema_extra_.ui_component = ui_component
if ui_hidden is not None:
@@ -547,6 +689,28 @@ def InputField(
json_schema_extra_.ui_order = ui_order
if ui_choice_labels is not None:
json_schema_extra_.ui_choice_labels = ui_choice_labels
if ui_model_base is not None:
if isinstance(ui_model_base, list):
json_schema_extra_.ui_model_base = ui_model_base
else:
json_schema_extra_.ui_model_base = [ui_model_base]
if ui_model_type is not None:
if isinstance(ui_model_type, list):
json_schema_extra_.ui_model_type = ui_model_type
else:
json_schema_extra_.ui_model_type = [ui_model_type]
if ui_model_variant is not None:
if isinstance(ui_model_variant, list):
json_schema_extra_.ui_model_variant = ui_model_variant
else:
json_schema_extra_.ui_model_variant = [ui_model_variant]
if ui_model_format is not None:
if isinstance(ui_model_format, list):
json_schema_extra_.ui_model_format = ui_model_format
else:
json_schema_extra_.ui_model_format = [ui_model_format]
if ui_type is not None:
json_schema_extra_.ui_type = ui_type
"""
There is a conflict between the typing of invocation definitions and the typing of an invocation's
@@ -648,20 +812,20 @@ def OutputField(
"""
Creates an output field for an invocation output.
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization)
that adds a few extra parameters to support graph execution and the node editor UI.
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
In some situations, the field's type is not enough to infer the correct UI type. \
For example, model selection fields should render a dropdown UI component to select a model. \
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
Args:
ui_type: Optionally provides an extra type hint for the UI. In some situations, the field's type is not enough
to infer the correct UI type. For example, Scheduler fields are enums, but we want to render a special scheduler
dropdown in the UI. Use `UIType.Scheduler` to indicate this.
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
ui_hidden: Specifies whether or not this field should be hidden in the UI.
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
ui_order: Specifies the order in which this field should be rendered in the UI. If omitted, the field will be
rendered after all fields with an explicit order, in the order they are defined in the Invocation class.
"""
return Field(
default=default,
title=title,
@@ -679,9 +843,9 @@ def OutputField(
min_length=min_length,
max_length=max_length,
json_schema_extra=OutputFieldJSONSchemaExtra(
ui_type=ui_type,
ui_hidden=ui_hidden,
ui_order=ui_order,
ui_type=ui_type,
field_kind=FieldKind.Output,
).model_dump(exclude_none=True),
)

View File

@@ -4,9 +4,10 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
from invokeai.app.invocations.model import ControlLoRAField, ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@invocation_output("flux_control_lora_loader_output")
@@ -29,7 +30,10 @@ class FluxControlLoRALoaderInvocation(BaseInvocation):
"""LoRA model and Image to use with FLUX transformer generation."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
description=FieldDescriptions.control_lora_model,
title="Control LoRA",
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.ControlLoRa,
)
image: ImageField = InputField(description="The image to encode.")
weight: float = InputField(description="The weight of the LoRA.", default=1.0)

View File

@@ -6,11 +6,12 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
class FluxControlNetField(BaseModel):
@@ -57,7 +58,9 @@ class FluxControlNetInvocation(BaseInvocation):
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
description=FieldDescriptions.controlnet_model,
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.ControlNet,
)
control_weight: float | list[float] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"

View File

@@ -5,7 +5,7 @@ from pydantic import field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, UIType
from invokeai.app.invocations.fields import InputField
from invokeai.app.invocations.ip_adapter import (
CLIP_VISION_MODEL_MAP,
IPAdapterField,
@@ -20,6 +20,7 @@ from invokeai.backend.model_manager.config import (
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@invocation(
@@ -36,7 +37,10 @@ class FluxIPAdapterInvocation(BaseInvocation):
image: ImageField = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.", title="IP-Adapter Model", ui_type=UIType.IPAdapterModel
description="The IP-Adapter model.",
title="IP-Adapter Model",
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.IPAdapter,
)
# Currently, the only known ViT model used by FLUX IP-Adapters is ViT-L.
clip_vision_model: Literal["ViT-L"] = InputField(description="CLIP Vision model to use.", default="ViT-L")

View File

@@ -6,10 +6,10 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@invocation_output("flux_lora_loader_output")
@@ -36,7 +36,10 @@ class FluxLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.LoRA,
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField | None = InputField(

View File

@@ -6,7 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
@@ -17,7 +17,7 @@ from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
)
from invokeai.backend.model_manager.taxonomy import SubModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
@invocation_output("flux_model_loader_output")
@@ -46,23 +46,30 @@ class FluxModelLoaderInvocation(BaseInvocation):
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.Main,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
description=FieldDescriptions.t5_encoder,
input=Input.Direct,
title="T5 Encoder",
ui_model_type=ModelType.T5Encoder,
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
ui_model_type=ModelType.CLIPEmbed,
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
description=FieldDescriptions.vae_model,
title="VAE",
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.VAE,
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:

View File

@@ -18,7 +18,6 @@ from invokeai.app.invocations.fields import (
InputField,
OutputField,
TensorField,
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
@@ -64,7 +63,8 @@ class FluxReduxInvocation(BaseInvocation):
redux_model: ModelIdentifierField = InputField(
description="The FLUX Redux model to use.",
title="FLUX Redux Model",
ui_type=UIType.FluxReduxModel,
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.FluxRedux,
)
downsampling_factor: int = InputField(
ge=1,

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@@ -85,7 +85,8 @@ class IPAdapterInvocation(BaseInvocation):
description="The IP-Adapter model.",
title="IP-Adapter Model",
ui_order=-1,
ui_type=UIType.IPAdapterModel,
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusionXL],
ui_model_type=ModelType.IPAdapter,
)
clip_vision_model: Literal["ViT-H", "ViT-G", "ViT-L"] = InputField(
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",

View File

@@ -6,11 +6,12 @@ from pydantic import field_validator
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline
from invokeai.backend.model_manager.taxonomy import ModelType
from invokeai.backend.util.devices import TorchDevice
@@ -34,7 +35,7 @@ class LlavaOnevisionVllmInvocation(BaseInvocation):
vllm_model: ModelIdentifierField = InputField(
title="LLaVA Model Type",
description=FieldDescriptions.vllm_model,
ui_type=UIType.LlavaOnevisionModel,
ui_model_type=ModelType.LlavaOnevision,
)
@field_validator("images", mode="before")

View File

@@ -53,7 +53,7 @@ from invokeai.app.invocations.primitives import (
from invokeai.app.invocations.scheduler import SchedulerOutput
from invokeai.app.invocations.t2i_adapter import T2IAdapterField, T2IAdapterInvocation
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import ModelType, SubModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.version import __version__
@@ -473,7 +473,6 @@ class MetadataToModelOutput(BaseInvocationOutput):
model: ModelIdentifierField = OutputField(
description=FieldDescriptions.main_model,
title="Model",
ui_type=UIType.MainModel,
)
name: str = OutputField(description="Model Name", title="Name")
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
@@ -488,7 +487,6 @@ class MetadataToSDXLModelOutput(BaseInvocationOutput):
model: ModelIdentifierField = OutputField(
description=FieldDescriptions.main_model,
title="Model",
ui_type=UIType.SDXLMainModel,
)
name: str = OutputField(description="Model Name", title="Name")
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
@@ -519,8 +517,7 @@ class MetadataToModelInvocation(BaseInvocation, WithMetadata):
input=Input.Direct,
)
default_value: ModelIdentifierField = InputField(
description="The default model to use if not found in the metadata",
ui_type=UIType.MainModel,
description="The default model to use if not found in the metadata", ui_model_type=ModelType.Main
)
_validate_custom_label = model_validator(mode="after")(validate_custom_label)
@@ -575,7 +572,8 @@ class MetadataToSDXLModelInvocation(BaseInvocation, WithMetadata):
)
default_value: ModelIdentifierField = InputField(
description="The default SDXL Model to use if not found in the metadata",
ui_type=UIType.SDXLMainModel,
ui_model_type=ModelType.Main,
ui_model_base=BaseModelType.StableDiffusionXL,
)
_validate_custom_label = model_validator(mode="after")(validate_custom_label)

View File

@@ -9,7 +9,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import (
@@ -145,7 +145,7 @@ class ModelIdentifierInvocation(BaseInvocation):
@invocation(
"main_model_loader",
title="Main Model - SD1.5",
title="Main Model - SD1.5, SD2",
tags=["model"],
category="model",
version="1.0.4",
@@ -153,7 +153,11 @@ class ModelIdentifierInvocation(BaseInvocation):
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
model: ModelIdentifierField = InputField(
description=FieldDescriptions.main_model,
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2],
ui_model_type=ModelType.Main,
)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
@@ -187,7 +191,10 @@ class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_base=BaseModelType.StableDiffusion1,
ui_model_type=ModelType.LoRA,
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
@@ -250,7 +257,9 @@ class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_type=ModelType.LoRA,
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
@@ -332,7 +341,10 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_base=BaseModelType.StableDiffusionXL,
ui_model_type=ModelType.LoRA,
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
@@ -473,13 +485,26 @@ class SDXLLoRACollectionLoader(BaseInvocation):
@invocation(
"vae_loader", title="VAE Model - SD1.5, SDXL, SD3, FLUX", tags=["vae", "model"], category="model", version="1.0.4"
"vae_loader",
title="VAE Model - SD1.5, SD2, SDXL, SD3, FLUX",
tags=["vae", "model"],
category="model",
version="1.0.4",
)
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
description=FieldDescriptions.vae_model,
title="VAE",
ui_model_base=[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
BaseModelType.StableDiffusion3,
BaseModelType.Flux,
],
ui_model_type=ModelType.VAE,
)
def invoke(self, context: InvocationContext) -> VAEOutput:

View File

@@ -6,14 +6,14 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.model_manager.taxonomy import SubModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, ClipVariantType, ModelType, SubModelType
@invocation_output("sd3_model_loader_output")
@@ -39,36 +39,43 @@ class Sd3ModelLoaderInvocation(BaseInvocation):
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sd3_model,
ui_type=UIType.SD3MainModel,
input=Input.Direct,
ui_model_base=BaseModelType.StableDiffusion3,
ui_model_type=ModelType.Main,
)
t5_encoder_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.t5_encoder,
ui_type=UIType.T5EncoderModel,
input=Input.Direct,
title="T5 Encoder",
default=None,
ui_model_type=ModelType.T5Encoder,
)
clip_l_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPLEmbedModel,
input=Input.Direct,
title="CLIP L Encoder",
default=None,
ui_model_type=ModelType.CLIPEmbed,
ui_model_variant=ClipVariantType.L,
)
clip_g_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.clip_g_model,
ui_type=UIType.CLIPGEmbedModel,
input=Input.Direct,
title="CLIP G Encoder",
default=None,
ui_model_type=ModelType.CLIPEmbed,
ui_model_variant=ClipVariantType.G,
)
vae_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE", default=None
description=FieldDescriptions.vae_model,
title="VAE",
default=None,
ui_model_base=BaseModelType.StableDiffusion3,
ui_model_type=ModelType.VAE,
)
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:

View File

@@ -1,8 +1,8 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import SubModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
@invocation_output("sdxl_model_loader_output")
@@ -29,7 +29,9 @@ class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
description=FieldDescriptions.sdxl_main_model,
ui_model_base=BaseModelType.StableDiffusionXL,
ui_model_type=ModelType.Main,
)
# TODO: precision?
@@ -67,7 +69,9 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
description=FieldDescriptions.sdxl_refiner_model,
ui_model_base=BaseModelType.StableDiffusionXLRefiner,
ui_model_type=ModelType.Main,
)
# TODO: precision?

View File

@@ -11,7 +11,6 @@ from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
UIType,
WithBoard,
WithMetadata,
)
@@ -19,6 +18,7 @@ from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import ModelType
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile
@@ -33,7 +33,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image_to_image_model: ModelIdentifierField = InputField(
title="Image-to-Image Model",
description=FieldDescriptions.spandrel_image_to_image_model,
ui_type=UIType.SpandrelImageToImageModel,
ui_model_type=ModelType.SpandrelImageToImage,
)
tile_size: int = InputField(
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."

View File

@@ -8,11 +8,12 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
class T2IAdapterField(BaseModel):
@@ -60,7 +61,8 @@ class T2IAdapterInvocation(BaseInvocation):
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
ui_order=-1,
ui_type=UIType.T2IAdapterModel,
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusionXL],
ui_model_type=ModelType.T2IAdapter,
)
weight: Union[float, list[float]] = InputField(
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from typing import Any, Coroutine, Optional
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
@@ -22,6 +23,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
@@ -135,6 +137,19 @@ class SessionQueueBase(ABC):
"""Deletes all queue items except in-progress items"""
pass
@abstractmethod
def list_queue_items(
self,
queue_id: str,
limit: int,
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets a page of session queue items. Do not remove."""
pass
@abstractmethod
def list_all_queue_items(
self,

View File

@@ -34,6 +34,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@@ -588,6 +589,59 @@ class SqliteSessionQueue(SessionQueueBase):
)
return self.get_queue_item(item_id)
def list_queue_items(
self,
queue_id: str,
limit: int,
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
with self._db.transaction() as cursor_:
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
items.pop()
has_more = True
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
def list_all_queue_items(
self,
queue_id: str,

View File

@@ -207,15 +207,24 @@ class IPAdapterPlusXL(IPAdapterPlus):
def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
state_dict: IPAdapterStateDict = {
"ip_adapter": {},
"image_proj": {},
"adapter_modules": {}, # added for noobai-mark-ipa
"image_proj_model": {}, # added for noobai-mark-ipa
}
if ip_adapter_ckpt_path.suffix == ".safetensors":
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
for key in model.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
elif key.startswith("ip_adapter."):
if key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
elif key.startswith("image_proj_model."):
state_dict["image_proj_model"][key.replace("image_proj_model.", "")] = model[key]
elif key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
elif key.startswith("adapter_modules."):
state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = model[key]
else:
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
else:

View File

@@ -104,6 +104,7 @@
"copy": "Copy",
"copyError": "$t(gallery.copy) Error",
"clipboard": "Clipboard",
"crop": "Crop",
"on": "On",
"off": "Off",
"or": "or",
@@ -242,7 +243,10 @@
"resultSubtitle": "Choose how to handle the expanded prompt:",
"replace": "Replace",
"insert": "Insert",
"discard": "Discard"
"discard": "Discard",
"noPromptHistory": "No prompt history recorded.",
"noMatchingPrompts": "No matching prompts in history.",
"toSwitchBetweenPrompts": "to switch between prompts."
},
"queue": {
"queue": "Queue",
@@ -480,6 +484,14 @@
"title": "Focus Prompt",
"desc": "Move cursor focus to the positive prompt."
},
"promptHistoryPrev": {
"title": "Previous Prompt in History",
"desc": "When the prompt is focused, move to the previous (older) prompt in your history."
},
"promptHistoryNext": {
"title": "Next Prompt in History",
"desc": "When the prompt is focused, move to the next (newer) prompt in your history."
},
"toggleLeftPanel": {
"title": "Toggle Left Panel",
"desc": "Show or hide the left panel."
@@ -1258,6 +1270,7 @@
"infillColorValue": "Fill Color",
"info": "Info",
"startingFrameImage": "Start Frame",
"startingFrameImageAspectRatioWarning": "Image aspect ratio does not match the video aspect ratio ({{videoAspectRatio}}). This could lead to unexpected cropping during video generation.",
"invoke": {
"addingImagesTo": "Adding images to",
"modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your account settings to upgrade.",

View File

@@ -131,7 +131,8 @@
"notInstalled": "Non $t(common.installed)",
"prevPage": "Pagina precedente",
"nextPage": "Pagina successiva",
"resetToDefaults": "Ripristina impostazioni predefinite"
"resetToDefaults": "Ripristina impostazioni predefinite",
"crop": "Ritaglia"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -278,6 +279,14 @@
"selectVideoTab": {
"title": "Seleziona la scheda Video",
"desc": "Seleziona la scheda Video."
},
"promptHistoryPrev": {
"title": "Prompt precedente nella cronologia",
"desc": "Quando il prompt è attivo, passa al prompt precedente (più vecchio) nella cronologia."
},
"promptHistoryNext": {
"title": "Prossimo prompt nella cronologia",
"desc": "Quando il prompt è attivo, passa al prompt successivo (più recente) nella cronologia."
}
},
"hotkeys": "Tasti di scelta rapida",
@@ -875,7 +884,8 @@
"video": "Video",
"resolution": "Risoluzione",
"downloadImage": "Scarica l'immagine",
"showOptionsPanel": "Mostra pannello laterale (O o T)"
"showOptionsPanel": "Mostra pannello laterale (O o T)",
"startingFrameImageAspectRatioWarning": "Le proporzioni dell'immagine non corrispondono alle proporzioni del video ({{videoAspectRatio}}). Ciò potrebbe causare ritagli imprevisti durante la generazione del video."
},
"settings": {
"models": "Modelli",
@@ -2095,7 +2105,10 @@
"generateFromImage": "Genera prompt dall'immagine",
"resultTitle": "Espansione del prompt completata",
"resultSubtitle": "Scegli come gestire il prompt espanso:",
"insert": "Inserisci"
"insert": "Inserisci",
"noPromptHistory": "Nessuna cronologia di prompt registrata.",
"noMatchingPrompts": "Nessun prompt corrispondente nella cronologia.",
"toSwitchBetweenPrompts": "per passare da un prompt all'altro."
},
"controlLayers": {
"addLayer": "Aggiungi Livello",
@@ -2791,7 +2804,8 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"items": [
"Seleziona oggetto v2: selezione degli oggetti migliorata con input di punti e riquadri o prompt di testo.",
"Regolazioni del livello raster: regola facilmente la luminosità, il contrasto, la saturazione, le curve e altro ancora del livello."
"Regolazioni del livello raster: regola facilmente la luminosità, il contrasto, la saturazione, le curve e altro ancora del livello.",
"Cronologia prompt: rivedi e richiama rapidamente i tuoi ultimi 100 prompt."
],
"watchUiUpdatesOverview": "Guarda la panoramica degli aggiornamenti dell'interfaccia utente"
},

View File

@@ -2,6 +2,7 @@ import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { CropImageModal } from 'features/cropper/components/CropImageModal';
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
import { DeleteVideoModal } from 'features/deleteVideoModal/components/DeleteVideoModal';
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
@@ -58,6 +59,7 @@ export const GlobalModalIsolator = memo(() => {
<CanvasPasteModal />
</CanvasManagerProviderGate>
<LoadWorkflowFromGraphModal />
<CropImageModal />
</>
);
});

View File

@@ -4,7 +4,6 @@ import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { withResultAsync } from 'common/util/result';
import { canvasReset } from 'features/controlLayers/store/actions';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { sentImageToCanvas } from 'features/gallery/store/actions';
@@ -164,7 +163,6 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
case 'generation':
// Go to the generate tab, open the launchpad
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
store.dispatch(paramsReset());
break;
case 'canvas':
// Go to the canvas tab, open the launchpad

View File

@@ -12,7 +12,13 @@ import {
} from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier, isFLUXReduxConfig, isIPAdapterConfig } from 'features/controlLayers/store/types';
import {
getEntityIdentifier,
isFLUXReduxConfig,
isIPAdapterConfig,
isRegionalGuidanceFLUXReduxConfig,
isRegionalGuidanceIPAdapterConfig,
} from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { modelSelected } from 'features/parameters/store/actions';
import {
@@ -252,7 +258,7 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
if (!isIPAdapterConfig(config)) {
if (!isRegionalGuidanceIPAdapterConfig(config)) {
return;
}
@@ -295,7 +301,7 @@ const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
if (!isFLUXReduxConfig(config)) {
if (!isRegionalGuidanceFLUXReduxConfig(config)) {
return;
}

View File

@@ -90,7 +90,7 @@ export const RasterLayerAdjustmentsPanel = memo(() => {
}
const rect = adapter.transformer.getRelativeRect();
try {
await adapter.renderer.rasterize({ rect, replaceObjects: true });
await adapter.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1 } });
// Clear adjustments after baking
dispatch(rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: null }));
} catch {

View File

@@ -1,12 +1,16 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { objectEquals } from '@observ33r/object-equals';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import { bboxSizeOptimized, bboxSizeRecalled } from 'features/controlLayers/store/canvasSlice';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { sizeOptimized, sizeRecalled } from 'features/controlLayers/store/paramsSlice';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import type { CroppableImageWithDims } from 'features/controlLayers/store/types';
import { imageDTOToCroppableImage, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { Editor } from 'features/cropper/lib/editor';
import { cropImageModalApi } from 'features/cropper/store';
import type { setGlobalReferenceImageDndTarget, setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
@@ -14,14 +18,14 @@ import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PiArrowCounterClockwiseBold, PiCropBold, PiRulerBold } from 'react-icons/pi';
import { useGetImageDTOQuery, useUploadImageMutation } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { $isConnected } from 'services/events/stores';
type Props<T extends typeof setGlobalReferenceImageDndTarget | typeof setRegionalGuidanceReferenceImageDndTarget> = {
image: ImageWithDims | null;
onChangeImage: (imageDTO: ImageDTO | null) => void;
image: CroppableImageWithDims | null;
onChangeImage: (croppableImage: CroppableImageWithDims | null) => void;
dndTarget: T;
dndTargetData: ReturnType<T['getData']>;
};
@@ -38,20 +42,28 @@ export const RefImageImage = memo(
const isConnected = useStore($isConnected);
const tab = useAppSelector(selectActiveTab);
const isStaging = useCanvasIsStaging();
const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
const imageWithDims = image?.crop?.image ?? image?.original.image ?? null;
const croppedImageDTOReq = useGetImageDTOQuery(image?.crop?.image?.image_name ?? skipToken);
const originalImageDTOReq = useGetImageDTOQuery(image?.original.image.image_name ?? skipToken);
const [uploadImage] = useUploadImageMutation();
const originalImageDTO = originalImageDTOReq.currentData;
const croppedImageDTO = croppedImageDTOReq.currentData;
const imageDTO = croppedImageDTO ?? originalImageDTO;
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);
useEffect(() => {
if (isConnected && isError) {
if ((isConnected && croppedImageDTOReq.isError) || originalImageDTOReq.isError) {
handleResetControlImage();
}
}, [handleResetControlImage, isError, isConnected]);
}, [handleResetControlImage, isConnected, croppedImageDTOReq.isError, originalImageDTOReq.isError]);
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
onChangeImage(imageDTO);
onChangeImage(imageDTOToCroppableImage(imageDTO));
},
[onChangeImage]
);
@@ -70,13 +82,67 @@ export const RefImageImage = memo(
}
}, [imageDTO, isStaging, store, tab]);
const edit = useCallback(() => {
if (!originalImageDTO) {
return;
}
// We will create a new editor instance each time the user wants to edit
const editor = new Editor();
// When the user applies the crop, we will upload the cropped image and store the applied crop box so if the user
// re-opens the editor they see the same crop
const onApplyCrop = async () => {
const box = editor.getCropBox();
if (objectEquals(box, image?.crop?.box)) {
// If the box hasn't changed, don't do anything
return;
}
if (!box || objectEquals(box, { x: 0, y: 0, width: originalImageDTO.width, height: originalImageDTO.height })) {
// There is a crop applied but it is the whole iamge - revert to original image
onChangeImage(imageDTOToCroppableImage(originalImageDTO));
return;
}
const blob = await editor.exportImage('blob');
const file = new File([blob], 'image.png', { type: 'image/png' });
const newCroppedImageDTO = await uploadImage({
file,
is_intermediate: true,
image_category: 'user',
}).unwrap();
onChangeImage(
imageDTOToCroppableImage(originalImageDTO, {
image: imageDTOToImageWithDims(newCroppedImageDTO),
box,
ratio: editor.getCropAspectRatio(),
})
);
};
const onReady = async () => {
const initial = image?.crop ? { cropBox: image.crop.box, aspectRatio: image.crop.ratio } : undefined;
// Load the image into the editor and open the modal once it's ready
await editor.loadImage(originalImageDTO.image_url, initial);
};
cropImageModalApi.open({ editor, onApplyCrop, onReady });
}, [image?.crop, onChangeImage, originalImageDTO, uploadImage]);
return (
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
<Flex
position="relative"
w="full"
h="full"
alignItems="center"
data-error={!imageDTO && !imageWithDims?.image_name}
>
{!imageDTO && (
<UploadImageIconButton
w="full"
h="full"
isError={!imageDTO && !image?.image_name}
isError={!imageDTO && !imageWithDims?.image_name}
onUpload={onUpload}
fontSize={36}
/>
@@ -99,6 +165,15 @@ export const RefImageImage = memo(
isDisabled={!imageDTO || (tab === 'canvas' && isStaging)}
/>
</Flex>
<Flex position="absolute" flexDir="column" top={2} insetInlineStart={2} gap={1}>
<DndImageIcon
onClick={edit}
icon={<PiCropBold size={16} />}
tooltip={t('common.crop')}
isDisabled={!imageDTO}
/>
</Flex>
</>
)}
<DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />

View File

@@ -13,7 +13,7 @@ import {
selectRefImageEntityIds,
selectSelectedRefEntityId,
} from 'features/controlLayers/store/refImagesSlice';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
import { addGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
@@ -92,7 +92,7 @@ const AddRefImageDropTargetAndButton = memo(() => {
({
onUpload: (imageDTO: ImageDTO) => {
const config = getDefaultRefImageConfig(getState);
config.image = imageDTOToImageWithDims(imageDTO);
config.image = imageDTOToCroppableImage(imageDTO);
dispatch(refImageAdded({ overrides: { config } }));
},
allowMultiple: false,

View File

@@ -1,6 +1,5 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, IconButton, Image, Skeleton, Text, Tooltip } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { round } from 'es-toolkit/compat';
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
@@ -15,7 +14,7 @@ import { isIPAdapterConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { useImageDTOFromCroppableImage } from 'services/api/endpoints/images';
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
@@ -72,7 +71,8 @@ export const RefImagePreview = memo(() => {
const selectedEntityId = useAppSelector(selectSelectedRefEntityId);
const isPanelOpen = useAppSelector(selectIsRefImagePanelOpen);
const [showWeightDisplay, setShowWeightDisplay] = useState(false);
const { data: imageDTO } = useGetImageDTOQuery(entity.config.image?.image_name ?? skipToken);
const imageDTO = useImageDTOFromCroppableImage(entity.config.image);
const sx = useMemo(() => {
if (!isIPAdapterConfig(entity.config)) {
@@ -145,7 +145,7 @@ export const RefImagePreview = memo(() => {
overflow="hidden"
>
<Image
src={imageDTO?.thumbnail_url}
src={imageDTO?.image_url}
objectFit="contain"
aspectRatio="1/1"
height={imageDTO?.height}

View File

@@ -30,6 +30,7 @@ import {
} from 'features/controlLayers/store/refImagesSlice';
import type {
CLIPVisionModelV2,
CroppableImageWithDims,
FLUXReduxImageInfluence as FLUXReduxImageInfluenceType,
IPMethodV2,
} from 'features/controlLayers/store/types';
@@ -42,7 +43,6 @@ import type {
ChatGPT4oModelConfig,
FLUXKontextModelConfig,
FLUXReduxModelConfig,
ImageDTO,
IPAdapterModelConfig,
} from 'services/api/types';
@@ -104,15 +104,19 @@ const RefImageSettingsContent = memo(() => {
);
const onChangeImage = useCallback(
(imageDTO: ImageDTO | null) => {
dispatch(refImageImageChanged({ id, imageDTO }));
(croppableImage: CroppableImageWithDims | null) => {
dispatch(refImageImageChanged({ id, croppableImage }));
},
[dispatch, id]
);
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
() => setGlobalReferenceImageDndTarget.getData({ id }, config.image?.image_name),
[id, config.image?.image_name]
() =>
setGlobalReferenceImageDndTarget.getData(
{ id },
config.image?.crop?.image.image_name ?? config.image?.original.image.image_name
),
[id, config.image?.crop?.image.image_name, config.image?.original.image.image_name]
);
const isFLUX = useAppSelector(selectIsFLUX);

View File

@@ -6,7 +6,6 @@ import { FLUXReduxImageInfluence } from 'features/controlLayers/components/commo
import { IPAdapterCLIPVisionModel } from 'features/controlLayers/components/common/IPAdapterCLIPVisionModel';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { IPAdapterMethod } from 'features/controlLayers/components/RefImage/IPAdapterMethod';
import { RefImageImage } from 'features/controlLayers/components/RefImage/RefImageImage';
import { RegionalGuidanceIPAdapterSettingsEmptyState } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState';
import { RegionalReferenceImageModel } from 'features/controlLayers/components/RegionalGuidance/RegionalReferenceImageModel';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -37,6 +36,8 @@ import { PiBoundingBoxBold, PiXBold } from 'react-icons/pi';
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { RegionalGuidanceRefImageImage } from './RegionalGuidanceRefImageImage';
type Props = {
referenceImageId: string;
};
@@ -114,7 +115,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
{ entityIdentifier, referenceImageId },
config.image?.image_name
),
[entityIdentifier, config.image?.image_name, referenceImageId]
[entityIdentifier, config.image, referenceImageId]
);
const pullBboxIntoIPAdapter = usePullBboxIntoRegionalGuidanceReferenceImage(entityIdentifier, referenceImageId);
@@ -170,7 +171,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
</Flex>
)}
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
<RefImageImage
<RegionalGuidanceRefImageImage
image={config.image}
onChangeImage={onChangeImage}
dndTarget={setRegionalGuidanceReferenceImageDndTarget}

View File

@@ -0,0 +1,103 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import { bboxSizeOptimized, bboxSizeRecalled } from 'features/controlLayers/store/canvasSlice';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { sizeOptimized, sizeRecalled } from 'features/controlLayers/store/paramsSlice';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import type { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { $isConnected } from 'services/events/stores';
type Props = {
image: ImageWithDims | null;
onChangeImage: (imageDTO: ImageDTO | null) => void;
dndTarget: typeof setRegionalGuidanceReferenceImageDndTarget;
dndTargetData: ReturnType<(typeof setRegionalGuidanceReferenceImageDndTarget)['getData']>;
};
export const RegionalGuidanceRefImageImage = memo(({ image, onChangeImage, dndTarget, dndTargetData }: Props) => {
const { t } = useTranslation();
const store = useAppStore();
const isConnected = useStore($isConnected);
const tab = useAppSelector(selectActiveTab);
const isStaging = useCanvasIsStaging();
const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);
useEffect(() => {
if (isConnected && isError) {
handleResetControlImage();
}
}, [handleResetControlImage, isError, isConnected]);
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
onChangeImage(imageDTO);
},
[onChangeImage]
);
const recallSizeAndOptimize = useCallback(() => {
if (!imageDTO || (tab === 'canvas' && isStaging)) {
return;
}
const { width, height } = imageDTO;
if (tab === 'canvas') {
store.dispatch(bboxSizeRecalled({ width, height }));
store.dispatch(bboxSizeOptimized());
} else if (tab === 'generate') {
store.dispatch(sizeRecalled({ width, height }));
store.dispatch(sizeOptimized());
}
}, [imageDTO, isStaging, store, tab]);
return (
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
{!imageDTO && (
<UploadImageIconButton
w="full"
h="full"
isError={!imageDTO && !image?.image_name}
onUpload={onUpload}
fontSize={36}
/>
)}
{imageDTO && (
<>
<DndImage imageDTO={imageDTO} borderRadius="base" borderWidth={1} borderStyle="solid" w="full" />
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('common.reset')}
/>
</Flex>
<Flex position="absolute" flexDir="column" bottom={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={recallSizeAndOptimize}
icon={<PiRulerBold size={16} />}
tooltip={t('parameters.useSize')}
isDisabled={!imageDTO || (tab === 'canvas' && isStaging)}
/>
</Flex>
</>
)}
<DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />
</Flex>
);
});
RegionalGuidanceRefImageImage.displayName = 'RegionalGuidanceRefImageImage';

View File

@@ -30,6 +30,7 @@ import type {
FluxKontextReferenceImageConfig,
Gemini2_5ReferenceImageConfig,
IPAdapterConfig,
RegionalGuidanceIPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import {
@@ -38,6 +39,7 @@ import {
initialFluxKontextReferenceImage,
initialGemini2_5ReferenceImage,
initialIPAdapter,
initialRegionalGuidanceIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
@@ -125,7 +127,7 @@ export const getDefaultRefImageConfig = (
return config;
};
export const getDefaultRegionalGuidanceRefImageConfig = (getState: AppGetState): IPAdapterConfig => {
export const getDefaultRegionalGuidanceRefImageConfig = (getState: AppGetState): RegionalGuidanceIPAdapterConfig => {
// Regional guidance ref images do not support ChatGPT-4o, so we always return the IP Adapter config.
const state = getState();
@@ -138,7 +140,7 @@ export const getDefaultRegionalGuidanceRefImageConfig = (getState: AppGetState):
const modelConfig = ipAdapterModelConfigs.find((m) => m.base === base);
// Clone the initial IP Adapter config and set the model if available.
const config = deepClone(initialIPAdapter);
const config = deepClone(initialRegionalGuidanceIPAdapter);
if (modelConfig) {
config.model = zModelIdentifierField.parse(modelConfig);

View File

@@ -32,7 +32,12 @@ import type {
RefImageState,
RegionalGuidanceRefImageState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import {
imageDTOToCroppableImage,
imageDTOToImageObject,
imageDTOToImageWithDims,
initialControlNet,
} from 'features/controlLayers/store/util';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import type { BoardId } from 'features/gallery/store/types';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
@@ -209,7 +214,7 @@ export const useNewGlobalReferenceImageFromBbox = () => {
const overrides: Partial<RefImageState> = {
config: {
...getDefaultRefImageConfig(getState),
image: imageDTOToImageWithDims(imageDTO),
image: imageDTOToCroppableImage(imageDTO),
},
};
dispatch(refImageAdded({ overrides }));
@@ -312,7 +317,7 @@ export const usePullBboxIntoGlobalReferenceImage = (id: string) => {
const arg = useMemo<UseSaveCanvasArg>(() => {
const onSave = (imageDTO: ImageDTO, _: Rect) => {
dispatch(refImageImageChanged({ id, imageDTO }));
dispatch(refImageImageChanged({ id, croppableImage: imageDTOToCroppableImage(imageDTO) }));
};
return {

View File

@@ -82,10 +82,10 @@ import {
IMAGEN_ASPECT_RATIOS,
isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID,
isFLUXReduxConfig,
isGemini2_5AspectRatioID,
isImagenAspectRatioID,
isIPAdapterConfig,
isRegionalGuidanceFLUXReduxConfig,
isRegionalGuidanceIPAdapterConfig,
zCanvasState,
} from './types';
import {
@@ -99,6 +99,7 @@ import {
initialControlNet,
initialFLUXRedux,
initialIPAdapter,
initialRegionalGuidanceIPAdapter,
initialT2IAdapter,
makeDefaultRasterLayerAdjustments,
} from './util';
@@ -804,7 +805,7 @@ const slice = createSlice({
if (!entity) {
return;
}
const config = { id: referenceImageId, config: deepClone(initialIPAdapter) };
const config = { id: referenceImageId, config: deepClone(initialRegionalGuidanceIPAdapter) };
merge(config, overrides);
entity.referenceImages.push(config);
},
@@ -847,7 +848,7 @@ const slice = createSlice({
if (!referenceImage) {
return;
}
if (!isIPAdapterConfig(referenceImage.config)) {
if (!isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
return;
}
@@ -864,7 +865,7 @@ const slice = createSlice({
if (!referenceImage) {
return;
}
if (!isIPAdapterConfig(referenceImage.config)) {
if (!isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
return;
}
referenceImage.config.beginEndStepPct = beginEndStepPct;
@@ -880,7 +881,7 @@ const slice = createSlice({
if (!referenceImage) {
return;
}
if (!isIPAdapterConfig(referenceImage.config)) {
if (!isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
return;
}
referenceImage.config.method = method;
@@ -899,7 +900,7 @@ const slice = createSlice({
if (!referenceImage) {
return;
}
if (!isFLUXReduxConfig(referenceImage.config)) {
if (!isRegionalGuidanceFLUXReduxConfig(referenceImage.config)) {
return;
}
@@ -928,7 +929,7 @@ const slice = createSlice({
return;
}
if (isIPAdapterConfig(referenceImage.config) && isFluxReduxModelConfig(modelConfig)) {
if (isRegionalGuidanceIPAdapterConfig(referenceImage.config) && isFluxReduxModelConfig(modelConfig)) {
// Switching from ip_adapter to flux_redux
referenceImage.config = {
...initialFLUXRedux,
@@ -938,7 +939,7 @@ const slice = createSlice({
return;
}
if (isFLUXReduxConfig(referenceImage.config) && isIPAdapterModelConfig(modelConfig)) {
if (isRegionalGuidanceFLUXReduxConfig(referenceImage.config) && isIPAdapterModelConfig(modelConfig)) {
// Switching from flux_redux to ip_adapter
referenceImage.config = {
...initialIPAdapter,
@@ -948,7 +949,7 @@ const slice = createSlice({
return;
}
if (isIPAdapterConfig(referenceImage.config)) {
if (isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
referenceImage.config.model = zModelIdentifierField.parse(modelConfig);
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
@@ -971,7 +972,7 @@ const slice = createSlice({
if (!referenceImage) {
return;
}
if (!isIPAdapterConfig(referenceImage.config)) {
if (!isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
return;
}
referenceImage.config.clipVisionModel = clipVisionModel;

View File

@@ -199,11 +199,7 @@ const slice = createSlice({
return;
}
if (state.positivePromptHistory.includes(prompt)) {
return;
}
state.positivePromptHistory.unshift(prompt);
state.positivePromptHistory = [prompt, ...state.positivePromptHistory.filter((p) => p !== prompt)];
if (state.positivePromptHistory.length > MAX_POSITIVE_PROMPT_HISTORY) {
state.positivePromptHistory = state.positivePromptHistory.slice(0, MAX_POSITIVE_PROMPT_HISTORY);

View File

@@ -6,13 +6,16 @@ import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { clamp } from 'es-toolkit/compat';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { FLUXReduxImageInfluence, RefImagesState } from 'features/controlLayers/store/types';
import type {
CroppableImageWithDims,
FLUXReduxImageInfluence,
RefImagesState,
} from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type {
ChatGPT4oModelConfig,
FLUXKontextModelConfig,
FLUXReduxModelConfig,
ImageDTO,
IPAdapterModelConfig,
} from 'services/api/types';
import { assert } from 'tsafe';
@@ -22,7 +25,6 @@ import type { CLIPVisionModelV2, IPMethodV2, RefImageState } from './types';
import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig, zRefImagesState } from './types';
import {
getReferenceImageState,
imageDTOToImageWithDims,
initialChatGPT4oReferenceImage,
initialFluxKontextReferenceImage,
initialFLUXRedux,
@@ -65,13 +67,13 @@ const slice = createSlice({
state.entities.push(...entities);
}
},
refImageImageChanged: (state, action: PayloadActionWithId<{ imageDTO: ImageDTO | null }>) => {
const { id, imageDTO } = action.payload;
refImageImageChanged: (state, action: PayloadActionWithId<{ croppableImage: CroppableImageWithDims | null }>) => {
const { id, croppableImage } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
entity.config.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
entity.config.image = croppableImage;
},
refImageIPAdapterMethodChanged: (state, action: PayloadActionWithId<{ method: IPMethodV2 }>) => {
const { id, method } = action.payload;

View File

@@ -37,6 +37,45 @@ export const zImageWithDims = z.object({
});
export type ImageWithDims = z.infer<typeof zImageWithDims>;
const zCropBox = z.object({
x: z.number().min(0),
y: z.number().min(0),
width: z.number().positive(),
height: z.number().positive(),
});
// This new schema is an extension of zImageWithDims, with an optional crop field.
//
// When we added cropping support to certain entities (e.g. Ref Images, video Starting Frame Image), we changed
// their schemas from using zImageWithDims to this new schema. To support loading pre-existing entities that
// were created before cropping was supported, we can use zod's preprocess to transform old data into the new format.
// Its essentially a data migration step.
//
// This parsing happens currently in two places:
// - Recalling metadata.
// - Loading/rehydrating persisted client state from storage.
export const zCroppableImageWithDims = z.preprocess(
(val) => {
try {
const imageWithDims = zImageWithDims.parse(val);
const migrated = { original: { image: deepClone(imageWithDims) } };
return migrated;
} catch {
return val;
}
},
z.object({
original: z.object({ image: zImageWithDims }),
crop: z
.object({
box: zCropBox,
ratio: z.number().gt(0).nullable(),
image: zImageWithDims,
})
.optional(),
})
);
export type CroppableImageWithDims = z.infer<typeof zCroppableImageWithDims>;
const zImageWithDimsDataURL = z.object({
dataURL: z.string(),
width: z.number().int().positive(),
@@ -235,7 +274,7 @@ export type CanvasObjectState = z.infer<typeof zCanvasObjectState>;
const zIPAdapterConfig = z.object({
type: z.literal('ip_adapter'),
image: zImageWithDims.nullable(),
image: zCroppableImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
@@ -244,21 +283,39 @@ const zIPAdapterConfig = z.object({
});
export type IPAdapterConfig = z.infer<typeof zIPAdapterConfig>;
const zRegionalGuidanceIPAdapterConfig = z.object({
type: z.literal('ip_adapter'),
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
method: zIPMethodV2,
clipVisionModel: zCLIPVisionModelV2,
});
export type RegionalGuidanceIPAdapterConfig = z.infer<typeof zRegionalGuidanceIPAdapterConfig>;
const zFLUXReduxImageInfluence = z.enum(['lowest', 'low', 'medium', 'high', 'highest']);
export const isFLUXReduxImageInfluence = (v: unknown): v is FLUXReduxImageInfluence =>
zFLUXReduxImageInfluence.safeParse(v).success;
export type FLUXReduxImageInfluence = z.infer<typeof zFLUXReduxImageInfluence>;
const zFLUXReduxConfig = z.object({
type: z.literal('flux_redux'),
image: zImageWithDims.nullable(),
image: zCroppableImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
imageInfluence: zFLUXReduxImageInfluence.default('highest'),
});
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
const zRegionalGuidanceFLUXReduxConfig = z.object({
type: z.literal('flux_redux'),
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
imageInfluence: zFLUXReduxImageInfluence.default('highest'),
});
type RegionalGuidanceFLUXReduxConfig = z.infer<typeof zRegionalGuidanceFLUXReduxConfig>;
const zChatGPT4oReferenceImageConfig = z.object({
type: z.literal('chatgpt_4o_reference_image'),
image: zImageWithDims.nullable(),
image: zCroppableImageWithDims.nullable(),
/**
* TODO(psyche): Technically there is no model for ChatGPT 4o reference images - it's just a field in the API call.
* But we use a model drop down to switch between different ref image types, so there needs to be a model here else
@@ -270,14 +327,14 @@ export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceIm
const zGemini2_5ReferenceImageConfig = z.object({
type: z.literal('gemini_2_5_reference_image'),
image: zImageWithDims.nullable(),
image: zCroppableImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
});
export type Gemini2_5ReferenceImageConfig = z.infer<typeof zGemini2_5ReferenceImageConfig>;
const zFluxKontextReferenceImageConfig = z.object({
type: z.literal('flux_kontext_reference_image'),
image: zImageWithDims.nullable(),
image: zCroppableImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
});
export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>;
@@ -307,6 +364,7 @@ export const isIPAdapterConfig = (config: RefImageState['config']): config is IP
export const isFLUXReduxConfig = (config: RefImageState['config']): config is FLUXReduxConfig =>
config.type === 'flux_redux';
export const isChatGPT4oReferenceImageConfig = (
config: RefImageState['config']
): config is ChatGPT4oReferenceImageConfig => config.type === 'chatgpt_4o_reference_image';
@@ -326,10 +384,18 @@ const zFill = z.object({ style: zFillStyle, color: zRgbColor });
const zRegionalGuidanceRefImageState = z.object({
id: zId,
config: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
config: z.discriminatedUnion('type', [zRegionalGuidanceIPAdapterConfig, zRegionalGuidanceFLUXReduxConfig]),
});
export type RegionalGuidanceRefImageState = z.infer<typeof zRegionalGuidanceRefImageState>;
export const isRegionalGuidanceIPAdapterConfig = (
config: RegionalGuidanceRefImageState['config']
): config is RegionalGuidanceIPAdapterConfig => config.type === 'ip_adapter';
export const isRegionalGuidanceFLUXReduxConfig = (
config: RegionalGuidanceRefImageState['config']
): config is RegionalGuidanceFLUXReduxConfig => config.type === 'flux_redux';
const zCanvasRegionalGuidanceState = zCanvasEntityBase.extend({
type: z.literal('regional_guidance'),
position: zCoordinate,

View File

@@ -10,6 +10,7 @@ import type {
ChatGPT4oReferenceImageConfig,
ControlLoRAConfig,
ControlNetConfig,
CroppableImageWithDims,
FluxKontextReferenceImageConfig,
FLUXReduxConfig,
Gemini2_5ReferenceImageConfig,
@@ -17,6 +18,7 @@ import type {
IPAdapterConfig,
RasterLayerAdjustments,
RefImageState,
RegionalGuidanceIPAdapterConfig,
RgbColor,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
@@ -45,6 +47,21 @@ export const imageDTOToImageWithDims = ({ image_name, width, height }: ImageDTO)
height,
});
export const imageDTOToCroppableImage = (
originalImageDTO: ImageDTO,
crop?: CroppableImageWithDims['crop']
): CroppableImageWithDims => {
const { image_name, width, height } = originalImageDTO;
const val: CroppableImageWithDims = {
original: { image: { image_name, width, height } },
};
if (crop) {
val.crop = deepClone(crop);
}
return val;
};
export const imageDTOToImageField = ({ image_name }: ImageDTO): ImageField => ({ image_name });
const DEFAULT_RG_MASK_FILL_COLORS: RgbColor[] = [
@@ -79,6 +96,15 @@ export const initialIPAdapter: IPAdapterConfig = {
clipVisionModel: 'ViT-H',
weight: 1,
};
export const initialRegionalGuidanceIPAdapter: RegionalGuidanceIPAdapterConfig = {
type: 'ip_adapter',
image: null,
model: null,
beginEndStepPct: [0, 1],
method: 'full',
clipVisionModel: 'ViT-H',
weight: 1,
};
export const initialFLUXRedux: FLUXReduxConfig = {
type: 'flux_redux',
image: null,

View File

@@ -0,0 +1,215 @@
import {
Button,
ButtonGroup,
Divider,
Flex,
FormControl,
FormLabel,
Select,
Spacer,
Text,
} from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import type { AspectRatioID } from 'features/controlLayers/store/types';
import { ASPECT_RATIO_MAP, isAspectRatioID } from 'features/controlLayers/store/types';
import type { CropBox } from 'features/cropper/lib/editor';
import { cropImageModalApi, type CropImageModalState } from 'features/cropper/store';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import React, { memo, useCallback, useEffect, useRef, useState } from 'react';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import { objectEntries } from 'tsafe';
type Props = {
editor: CropImageModalState['editor'];
onApplyCrop: CropImageModalState['onApplyCrop'];
onReady: CropImageModalState['onReady'];
};
const getAspectRatioString = (ratio: number | null): AspectRatioID => {
if (!ratio) {
return 'Free';
}
const entries = objectEntries(ASPECT_RATIO_MAP);
for (const [key, value] of entries) {
if (value.ratio === ratio) {
return key;
}
}
return 'Free';
};
export const CropImageEditor = memo(({ editor, onApplyCrop, onReady }: Props) => {
const containerRef = useRef<HTMLDivElement>(null);
const [zoom, setZoom] = useState(100);
const [cropBox, setCropBox] = useState<CropBox | null>(null);
const [aspectRatio, setAspectRatio] = useState<string>('free');
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const [uploadImage] = useUploadImageMutation({ fixedCacheKey: 'editorContainer' });
const setup = useCallback(
async (container: HTMLDivElement) => {
editor.init(container);
editor.onZoomChange((zoom) => {
setZoom(zoom);
});
editor.onCropBoxChange((crop) => {
setCropBox(crop);
});
editor.onAspectRatioChange((ratio) => {
setAspectRatio(getAspectRatioString(ratio));
});
await onReady();
editor.fitToContainer();
},
[editor, onReady]
);
useEffect(() => {
const container = containerRef.current;
if (!container) {
return;
}
setup(container);
const handleResize = () => {
editor.resize(container.clientWidth, container.clientHeight);
};
const resizeObserver = new ResizeObserver(handleResize);
resizeObserver.observe(container);
return () => {
resizeObserver.disconnect();
};
}, [editor, setup]);
const handleAspectRatioChange = useCallback(
(e: React.ChangeEvent<HTMLSelectElement>) => {
const newRatio = e.target.value;
if (!isAspectRatioID(newRatio)) {
return;
}
setAspectRatio(newRatio);
if (newRatio === 'Free') {
editor.setCropAspectRatio(null);
} else {
editor.setCropAspectRatio(ASPECT_RATIO_MAP[newRatio]?.ratio ?? null);
}
},
[editor]
);
const handleResetCrop = useCallback(() => {
editor.resetCrop();
}, [editor]);
const handleApplyCrop = useCallback(async () => {
await onApplyCrop();
cropImageModalApi.close();
}, [onApplyCrop]);
const handleCancelCrop = useCallback(() => {
cropImageModalApi.close();
}, []);
const handleExport = useCallback(async () => {
try {
const blob = await editor.exportImage('blob');
const file = new File([blob], 'image.png', { type: 'image/png' });
await uploadImage({
file,
is_intermediate: false,
image_category: 'user',
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
}).unwrap();
} catch (err) {
if (err instanceof Error && err.message.includes('tainted')) {
alert(
'Cannot export image: The image is from a different domain (CORS issue). To fix this:\n\n1. Load images from the same domain\n2. Use images from CORS-enabled sources\n3. Upload a local image file instead'
);
} else {
alert(`Export failed: ${err instanceof Error ? err.message : String(err)}`);
}
}
}, [autoAddBoardId, editor, uploadImage]);
const zoomIn = useCallback(() => {
editor.zoomIn();
}, [editor]);
const zoomOut = useCallback(() => {
editor.zoomOut();
}, [editor]);
const fitToContainer = useCallback(() => {
editor.fitToContainer();
}, [editor]);
const resetView = useCallback(() => {
editor.resetView();
}, [editor]);
return (
<Flex w="full" h="full" flexDir="column" gap={4}>
<Flex gap={2} alignItems="center">
<FormControl flex={1}>
<FormLabel>Aspect Ratio:</FormLabel>
<Select size="sm" value={aspectRatio} onChange={handleAspectRatioChange} w={32}>
<option value="Free">Free</option>
<option value="16:9">16:9</option>
<option value="3:2">3:2</option>
<option value="4:3">4:3</option>
<option value="1:1">1:1</option>
<option value="3:4">3:4</option>
<option value="2:3">2:3</option>
<option value="9:16">9:16</option>
</Select>
</FormControl>
<Spacer />
<ButtonGroup size="sm" isAttached={false}>
<Button onClick={fitToContainer}>Fit View</Button>
<Button onClick={resetView}>Reset View</Button>
<Button onClick={zoomIn}>Zoom In</Button>
<Button onClick={zoomOut}>Zoom Out</Button>
</ButtonGroup>
<Spacer />
<ButtonGroup size="sm" isAttached={false}>
<Button onClick={handleApplyCrop}>Apply</Button>
<Button onClick={handleResetCrop}>Reset</Button>
<Button onClick={handleCancelCrop}>Cancel</Button>
<Button onClick={handleExport}>Save to Assets</Button>
</ButtonGroup>
</Flex>
<Flex position="relative" w="full" h="full" bg="base.900">
<Flex position="absolute" inset={0} ref={containerRef} />
</Flex>
<Flex gap={2} color="base.300">
<Text>Mouse wheel: Zoom</Text>
<Divider orientation="vertical" />
<Text>Space + Drag: Pan</Text>
<Divider orientation="vertical" />
<Text>Drag crop box or handles to adjust</Text>
{cropBox && (
<>
<Divider orientation="vertical" />
<Text>
X: {Math.round(cropBox.x)}, Y: {Math.round(cropBox.y)}, Width: {Math.round(cropBox.width)}, Height:{' '}
{Math.round(cropBox.height)}
</Text>
</>
)}
<Spacer key="help-spacer" />
<Text key="help-zoom">Zoom: {Math.round(zoom * 100)}%</Text>
</Flex>
</Flex>
);
});
CropImageEditor.displayName = 'CropImageEditor';

View File

@@ -0,0 +1,29 @@
import { Modal, ModalBody, ModalContent, ModalHeader, ModalOverlay } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { cropImageModalApi } from 'features/cropper/store';
import { memo } from 'react';
import { CropImageEditor } from './CropImageEditor';
export const CropImageModal = memo(() => {
const state = useStore(cropImageModalApi.$state);
if (!state) {
return null;
}
return (
// This modal is always open when this component is rendered
<Modal isOpen={true} onClose={cropImageModalApi.close} isCentered useInert={false} size="full">
<ModalOverlay />
<ModalContent minH="unset" minW="unset" maxH="90vh" maxW="90vw" w="full" h="full" borderRadius="base">
<ModalHeader>Crop Image</ModalHeader>
<ModalBody px={4} pb={4} pt={0}>
<CropImageEditor editor={state.editor} onApplyCrop={state.onApplyCrop} onReady={state.onReady} />
</ModalBody>
</ModalContent>
</Modal>
);
});
CropImageModal.displayName = 'CropImageModal';

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,26 @@
import type { Editor } from 'features/cropper/lib/editor';
import { atom } from 'nanostores';
export type CropImageModalState = {
editor: Editor;
onApplyCrop: () => Promise<void> | void;
onReady: () => Promise<void> | void;
};
const $state = atom<CropImageModalState | null>(null);
const open = (state: CropImageModalState) => {
$state.set(state);
};
const close = () => {
const state = $state.get();
state?.editor.destroy();
$state.set(null);
};
export const cropImageModalApi = {
$state,
open,
close,
};

View File

@@ -236,8 +236,11 @@ const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, image
const deleteReferenceImages = (state: RootState, dispatch: AppDispatch, image_name: string) => {
selectReferenceImageEntities(state).forEach((entity) => {
if (entity.config.image?.image_name === image_name) {
dispatch(refImageImageChanged({ id: entity.id, imageDTO: null }));
if (
entity.config.image?.original.image.image_name === image_name ||
entity.config.image?.crop?.image.image_name === image_name
) {
dispatch(refImageImageChanged({ id: entity.id, croppableImage: null }));
}
});
};
@@ -284,7 +287,10 @@ export const getImageUsage = (
const isUpscaleImage = upscale.upscaleInitialImage?.image_name === image_name;
const isReferenceImage = refImages.entities.some(({ config }) => config.image?.image_name === image_name);
const isReferenceImage = refImages.entities.some(
({ config }) =>
config.image?.original.image.image_name === image_name || config.image?.crop?.image.image_name === image_name
);
const isRasterLayerImage = canvas.rasterLayers.entities.some(({ objects }) =>
objects.some((obj) => obj.type === 'image' && 'image_name' in obj.image && obj.image.image_name === image_name)

View File

@@ -3,7 +3,7 @@ import { IconButton } from '@invoke-ai/ui-library';
import type { MouseEvent } from 'react';
import { memo } from 'react';
const sx: SystemStyleObject = {
export const imageButtonSx: SystemStyleObject = {
minW: 0,
svg: {
transitionProperty: 'common',
@@ -31,7 +31,7 @@ export const DndImageIcon = memo((props: Props) => {
aria-label={tooltip}
icon={icon}
variant="link"
sx={sx}
sx={imageButtonSx}
data-testid={tooltip}
{...rest}
/>

View File

@@ -4,7 +4,7 @@ import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerH
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { refImageAdded } from 'features/controlLayers/store/refImagesSlice';
import type { CanvasEntityIdentifier, CanvasEntityType } from 'features/controlLayers/store/types';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
import { selectComparisonImages } from 'features/gallery/components/ImageViewer/common';
import type { BoardId } from 'features/gallery/store/types';
import {
@@ -211,7 +211,7 @@ export const addGlobalReferenceImageDndTarget: DndTarget<
handler: ({ sourceData, dispatch, getState }) => {
const { imageDTO } = sourceData.payload;
const config = getDefaultRefImageConfig(getState);
config.image = imageDTOToImageWithDims(imageDTO);
config.image = imageDTOToCroppableImage(imageDTO);
dispatch(refImageAdded({ overrides: { config } }));
},
};
@@ -641,7 +641,7 @@ export const videoFrameFromImageDndTarget: DndTarget<VideoFrameFromImageDndTarge
},
handler: ({ sourceData, dispatch }) => {
const { imageDTO } = sourceData.payload;
dispatch(startingFrameImageChanged(imageDTOToImageWithDims(imageDTO)));
dispatch(startingFrameImageChanged(imageDTOToCroppableImage(imageDTO)));
},
};
//#endregion

View File

@@ -1,4 +1,5 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
import { useItemDTOContextImageOnly } from 'features/gallery/contexts/ItemDTOContext';
import { startingFrameImageChanged } from 'features/parameters/store/videoSlice';
import { navigationApi } from 'features/ui/layouts/navigation-api';
@@ -13,7 +14,7 @@ export const ContextMenuItemSendToVideo = memo(() => {
const dispatch = useDispatch();
const onClick = useCallback(() => {
dispatch(startingFrameImageChanged(imageDTO));
dispatch(startingFrameImageChanged(imageDTOToCroppableImage(imageDTO)));
navigationApi.switchToTab('video');
}, [imageDTO, dispatch]);

View File

@@ -2,7 +2,7 @@ import { MenuItem } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/storeHooks';
import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
import { refImageAdded } from 'features/controlLayers/store/refImagesSlice';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
import { useItemDTOContextImageOnly } from 'features/gallery/contexts/ItemDTOContext';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
@@ -17,7 +17,7 @@ export const ContextMenuItemUseAsRefImage = memo(() => {
const onClickNewGlobalReferenceImageFromImage = useCallback(() => {
const { dispatch, getState } = store;
const config = getDefaultRefImageConfig(getState);
config.image = imageDTOToImageWithDims(imageDTO);
config.image = imageDTOToCroppableImage(imageDTO);
dispatch(refImageAdded({ overrides: { config } }));
toast({
id: 'SENT_TO_CANVAS',

View File

@@ -26,7 +26,12 @@ import type {
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import {
imageDTOToCroppableImage,
imageDTOToImageObject,
imageDTOToImageWithDims,
initialControlNet,
} from 'features/controlLayers/store/util';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import type { BoardId } from 'features/gallery/store/types';
@@ -44,7 +49,7 @@ import { assert } from 'tsafe';
export const setGlobalReferenceImage = (arg: { imageDTO: ImageDTO; id: string; dispatch: AppDispatch }) => {
const { imageDTO, id, dispatch } = arg;
dispatch(refImageImageChanged({ id, imageDTO }));
dispatch(refImageImageChanged({ id, croppableImage: imageDTOToCroppableImage(imageDTO) }));
};
export const setRegionalGuidanceReferenceImage = (arg: {

View File

@@ -975,7 +975,7 @@ const RefImages: CollectionMetadataHandler<RefImageState[]> = {
for (const refImage of parsed) {
if (refImage.config.image) {
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
await throwIfImageDoesNotExist(refImage.config.image.original.image.image_name, store);
}
if (refImage.config.model) {
await throwIfModelDoesNotExist(refImage.config.model.key, store);

View File

@@ -35,7 +35,7 @@ export const LaunchpadForm = memo(() => {
return (
<Flex flexDir="column" height="100%" gap={3}>
<ScrollableContent>
<Flex flexDir="column" gap={6} p={3}>
<Flex flexDir="column" gap={6} py={2}>
{/* Welcome Section */}
<Flex flexDir="column" gap={2} alignItems="flex-start">
<Heading size="md">{t('modelManager.launchpad.welcome')}</Heading>

View File

@@ -0,0 +1,45 @@
import { Badge, Button, Flex } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCheckBold, PiPlusBold } from 'react-icons/pi';
type Props = {
handleInstall: () => void;
isInstalled: boolean;
};
export const ModelResultItemActions = memo(({ handleInstall, isInstalled }: Props) => {
const { t } = useTranslation();
return (
<Flex gap={2} shrink={0} pt={1}>
{isInstalled ? (
// TODO: Add a link button to navigate to model
<Badge
variant="subtle"
colorScheme="green"
display="flex"
gap={1}
alignItems="center"
borderRadius="base"
h="24px"
>
<PiCheckBold size="14px" />
</Badge>
) : (
<Button
onClick={handleInstall}
rightIcon={<PiPlusBold size="14px" />}
textTransform="uppercase"
letterSpacing="wider"
fontSize="9px"
size="sm"
>
{t('modelManager.install')}
</Button>
)}
</Flex>
);
});
ModelResultItemActions.displayName = 'ModelResultItemActions';

View File

@@ -1,33 +1,56 @@
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Text } from '@invoke-ai/ui-library';
import { ModelResultItemActions } from 'features/modelManagerV2/subpanels/AddModelPanel/ModelResultItemActions';
import { memo, useCallback, useMemo } from 'react';
import type { ScanFolderResponse } from 'services/api/endpoints/models';
type Props = {
result: ScanFolderResponse[number];
installModel: (source: string) => void;
};
export const ScanModelResultItem = memo(({ result, installModel }: Props) => {
const { t } = useTranslation();
const scanFolderResultItemSx: SystemStyleObject = {
alignItems: 'center',
justifyContent: 'space-between',
w: '100%',
py: 2,
px: 1,
gap: 3,
borderBottomWidth: '1px',
borderColor: 'base.700',
};
export const ScanModelResultItem = memo(({ result, installModel }: Props) => {
const handleInstall = useCallback(() => {
installModel(result.path);
}, [installModel, result]);
const modelDisplayName = useMemo(() => {
const normalizedPath = result.path.replace(/\\/g, '/').replace(/\/+$/, '');
// Extract filename/folder name from path
const lastSlashIndex = normalizedPath.lastIndexOf('/');
return lastSlashIndex === -1 ? normalizedPath : normalizedPath.slice(lastSlashIndex + 1);
}, [result.path]);
const modelPathParts = result.path.split(/[/\\]/);
return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
<Flex sx={scanFolderResultItemSx}>
<Flex fontSize="sm" flexDir="column">
<Text fontWeight="semibold">{result.path.split('\\').slice(-1)[0]}</Text>
<Text variant="subtext">{result.path}</Text>
{/* Model Title */}
<Text fontWeight="semibold">{modelDisplayName}</Text>
{/* Model Path */}
<Flex flexWrap="wrap" color="base.200">
{modelPathParts.map((part, index) => (
<Text key={index} variant="subtext">
{part}
{index < modelPathParts.length - 1 && '/'}
</Text>
))}
</Flex>
</Flex>
<Box>
{result.is_installed ? (
<Badge>{t('common.installed')}</Badge>
) : (
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleInstall} size="sm" />
)}
</Box>
<ModelResultItemActions handleInstall={handleInstall} isInstalled={result.is_installed} />
</Flex>
);
});

View File

@@ -113,9 +113,9 @@ export const ScanModelsResults = memo(({ results }: ScanModelResultsProps) => {
</InputGroup>
</Flex>
</Flex>
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
<Flex height="100%" layerStyle="second" borderRadius="base" px={2}>
<ScrollableContent>
<Flex flexDir="column" gap={3}>
<Flex flexDir="column">
{filteredResults.map((result) => (
<ScanModelResultItem key={result.path} result={result} installModel={handleInstallOne} />
))}

View File

@@ -13,6 +13,7 @@ import { useStarterBundleInstallStatus } from 'features/modelManagerV2/hooks/use
import { t } from 'i18next';
import type { MouseEvent } from 'react';
import { useCallback } from 'react';
import { PiDownloadSimpleBold } from 'react-icons/pi';
import type { S } from 'services/api/types';
export const StarterBundleButton = ({ bundle, ...rest }: { bundle: S['StarterModelBundle'] } & ButtonProps) => {
@@ -33,8 +34,16 @@ export const StarterBundleButton = ({ bundle, ...rest }: { bundle: S['StarterMod
return (
<>
<Button onClick={onClickBundle} isDisabled={install.length === 0} {...rest}>
{bundle.name}
<Button
display="flex"
justifyContent="space-between"
gap={2}
onClick={onClickBundle}
isDisabled={install.length === 0}
{...rest}
>
<span>{bundle.name}</span>
<PiDownloadSimpleBold size="16px" />
</Button>
<ConfirmationAlertDialog
isOpen={isOpen}

View File

@@ -1,17 +1,30 @@
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Badge, Flex, Text } from '@invoke-ai/ui-library';
import { negate } from 'es-toolkit/compat';
import { flattenStarterModel, useBuildModelInstallArg } from 'features/modelManagerV2/hooks/useBuildModelsToInstall';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import { ModelResultItemActions } from 'features/modelManagerV2/subpanels/AddModelPanel/ModelResultItemActions';
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import type { StarterModel } from 'services/api/types';
const starterModelResultItemSx: SystemStyleObject = {
alignItems: 'start',
justifyContent: 'space-between',
w: '100%',
py: 2,
px: 1,
gap: 2,
borderBottomWidth: '1px',
borderColor: 'base.700',
};
type Props = {
starterModel: StarterModel;
};
export const StarterModelsResultItem = memo(({ starterModel }: Props) => {
const { t } = useTranslation();
const { getIsInstalled, buildModelInstallArg } = useBuildModelInstallArg();
@@ -40,22 +53,16 @@ export const StarterModelsResultItem = memo(({ starterModel }: Props) => {
}, [modelsToInstall, installModel, t]);
return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
<Flex sx={starterModelResultItemSx}>
<Flex fontSize="sm" flexDir="column">
<Flex gap={3}>
<Text fontWeight="semibold">{starterModel.name}</Text>
<Text variant="subtext">{starterModel.description}</Text>
<Flex gap={1} py={1} alignItems="center">
<Badge h="min-content">{starterModel.type.replaceAll('_', ' ')}</Badge>
<ModelBaseBadge base={starterModel.base} />
<Text fontWeight="semibold">{starterModel.name}</Text>
</Flex>
<Text variant="subtext">{starterModel.description}</Text>
</Flex>
<Box>
{isInstalled ? (
<Badge>{t('common.installed')}</Badge>
) : (
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
)}
</Box>
<ModelResultItemActions handleInstall={onClick} isInstalled={isInstalled} />
</Flex>
);
});

View File

@@ -48,9 +48,9 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
return (
<Flex flexDir="column" gap={3} height="100%">
<Flex justifyContent="space-between" alignItems="center">
<Flex gap={3} direction="column">
{size(results.starter_bundles) > 0 && (
<Flex gap={4} alignItems="center">
<Flex gap={4} alignItems="center" justifyContent="space-between" p={4} borderWidth="1px" rounded="base">
<Flex gap={2} alignItems="center">
<Text color="base.200" fontWeight="semibold">
{t('modelManager.starterBundles')}
@@ -73,7 +73,8 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
</Flex>
</Flex>
)}
<InputGroup w={64} size="xs">
<InputGroup w="100%" size="xs">
<Input
placeholder={t('modelManager.search')}
value={searchTerm}
@@ -96,9 +97,10 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
)}
</InputGroup>
</Flex>
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
<Flex height="100%" layerStyle="second" borderRadius="base" px={2}>
<ScrollableContent>
<Flex flexDir="column" gap={3}>
<Flex flexDir="column">
{filteredResults.map((result) => (
<StarterModelsResultItem key={result.source} starterModel={result} />
))}

View File

@@ -1,10 +1,12 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Button, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $installModelsTabIndex } from 'features/modelManagerV2/store/installModelsStore';
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiInfoBold } from 'react-icons/pi';
import { PiCubeBold, PiFolderOpenBold, PiInfoBold, PiLinkSimpleBold, PiShootingStarBold } from 'react-icons/pi';
import { SiHuggingface } from 'react-icons/si';
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
@@ -12,6 +14,12 @@ import { LaunchpadForm } from './AddModelPanel/LaunchpadForm/LaunchpadForm';
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
const installModelsTabSx: SystemStyleObject = {
display: 'flex',
gap: 2,
px: 2,
};
export const InstallModels = memo(() => {
const { t } = useTranslation();
const tabIndex = useStore($installModelsTabIndex);
@@ -29,21 +37,36 @@ export const InstallModels = memo(() => {
</Button>
</Flex>
<Tabs
variant="collapse"
height="50%"
variant="line"
height="100%"
display="flex"
flexDir="column"
index={tabIndex}
onChange={$installModelsTabIndex.set}
>
<TabList>
<Tab>{t('modelManager.launchpadTab')}</Tab>
<Tab>{t('modelManager.urlOrLocalPath')}</Tab>
<Tab>{t('modelManager.huggingFace')}</Tab>
<Tab>{t('modelManager.scanFolder')}</Tab>
<Tab>{t('modelManager.starterModels')}</Tab>
<Tab sx={installModelsTabSx}>
<PiCubeBold />
{t('modelManager.launchpadTab')}
</Tab>
<Tab sx={installModelsTabSx}>
<PiLinkSimpleBold />
{t('modelManager.urlOrLocalPath')}
</Tab>
<Tab sx={installModelsTabSx}>
<SiHuggingface />
{t('modelManager.huggingFace')}
</Tab>
<Tab sx={installModelsTabSx}>
<PiFolderOpenBold />
{t('modelManager.scanFolder')}
</Tab>
<Tab sx={installModelsTabSx}>
<PiShootingStarBold />
{t('modelManager.starterModels')}
</Tab>
</TabList>
<TabPanels p={3} height="100%">
<TabPanels height="100%">
<TabPanel height="100%">
<LaunchpadForm />
</TabPanel>

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectSelectedModelKey, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
@@ -8,6 +9,16 @@ import { PiPlusBold } from 'react-icons/pi';
import ModelList from './ModelManagerPanel/ModelList';
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
const modelManagerSx: SystemStyleObject = {
flexDir: 'column',
p: 4,
gap: 4,
borderRadius: 'base',
w: '50%',
minWidth: '360px',
h: 'full',
};
export const ModelManager = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
@@ -17,7 +28,7 @@ export const ModelManager = memo(() => {
const selectedModelKey = useAppSelector(selectSelectedModelKey);
return (
<Flex flexDir="column" layerStyle="first" p={4} gap={4} borderRadius="base" w="50%" h="full">
<Flex sx={modelManagerSx}>
<Flex w="full" gap={4} justifyContent="space-between" alignItems="center">
<Heading fontSize="xl" py={1}>
{t('common.modelManager')}
@@ -28,7 +39,7 @@ export const ModelManager = memo(() => {
</Button>
)}
</Flex>
<Flex flexDir="column" layerStyle="second" p={4} gap={4} borderRadius="base" w="full" h="full">
<Flex flexDir="column" gap={4} w="full" h="full">
<ModelListNavigation />
<ModelList />
</Flex>

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { PiImage } from 'react-icons/pi';
@@ -6,19 +7,23 @@ type Props = {
image_url?: string | null;
};
export const MODEL_IMAGE_THUMBNAIL_SIZE = '40px';
const FALLBACK_ICON_SIZE = '24px';
const MODEL_IMAGE_THUMBNAIL_SIZE = '54px';
const FALLBACK_ICON_SIZE = '28px';
const sharedSx: SystemStyleObject = {
rounded: 'base',
height: MODEL_IMAGE_THUMBNAIL_SIZE,
minWidth: MODEL_IMAGE_THUMBNAIL_SIZE,
bg: 'base.850',
borderWidth: '1px',
borderColor: 'base.750',
borderStyle: 'solid',
};
const ModelImage = ({ image_url }: Props) => {
if (!image_url) {
return (
<Flex
height={MODEL_IMAGE_THUMBNAIL_SIZE}
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
borderRadius="base"
alignItems="center"
justifyContent="center"
>
<Flex alignItems="center" justifyContent="center" sx={sharedSx}>
<Icon color="base.500" as={PiImage} boxSize={FALLBACK_ICON_SIZE} />
</Flex>
);
@@ -29,16 +34,14 @@ const ModelImage = ({ image_url }: Props) => {
src={image_url}
objectFit="cover"
objectPosition="50% 50%"
height={MODEL_IMAGE_THUMBNAIL_SIZE}
width={MODEL_IMAGE_THUMBNAIL_SIZE}
minHeight={MODEL_IMAGE_THUMBNAIL_SIZE}
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
borderRadius="base"
sx={sharedSx}
fallback={
<Flex
height={MODEL_IMAGE_THUMBNAIL_SIZE}
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
borderRadius="base"
sx={sharedSx}
alignItems="center"
justifyContent="center"
>

View File

@@ -1,32 +1,57 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { ConfirmationAlertDialog, Flex, IconButton, Spacer, Text, useDisclosure } from '@invoke-ai/ui-library';
import { Flex, Spacer, Text } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectModelManagerV2Slice, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import ModelFormatBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge';
import { toast } from 'features/toast/toast';
import { ModelDeleteButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelDeleteButton';
import { filesize } from 'filesize';
import type { MouseEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import ModelImage, { MODEL_IMAGE_THUMBNAIL_SIZE } from './ModelImage';
import ModelImage from './ModelImage';
type ModelListItemProps = {
model: AnyModelConfig;
};
const sx: SystemStyleObject = {
_hover: { bg: 'base.700' },
"&[aria-selected='true']": { bg: 'base.700' },
paddingInline: 3,
paddingBlock: 2,
position: 'relative',
rounded: 'base',
'&:after,&:before': {
content: `''`,
position: 'absolute',
pointerEvents: 'none',
},
'&:after': {
h: '1px',
bottom: '-0.5px',
insetInline: 3,
bg: 'base.850',
},
'&:before': {
left: 1,
w: 1,
insetBlock: 2,
rounded: 'base',
},
_hover: {
bg: 'base.850',
'& .delete-button': { opacity: 1 },
},
'& .delete-button': { opacity: 0 },
"&[aria-selected='false']:hover:before": { bg: 'base.750' },
"&[aria-selected='true']": {
bg: 'base.800',
'& .delete-button': { opacity: 1 },
},
"&[aria-selected='true']:before": { bg: 'invokeBlue.300' },
};
const ModelListItem = ({ model }: ModelListItemProps) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const selectIsSelected = useMemo(
() =>
@@ -37,58 +62,25 @@ const ModelListItem = ({ model }: ModelListItemProps) => {
[model.key]
);
const isSelected = useAppSelector(selectIsSelected);
const [deleteModel] = useDeleteModelsMutation();
const { isOpen, onOpen, onClose } = useDisclosure();
const handleSelectModel = useCallback(() => {
dispatch(setSelectedModelKey(model.key));
}, [model.key, dispatch]);
const onClickDeleteButton = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
onOpen();
},
[onOpen]
);
const handleModelDelete = useCallback(() => {
deleteModel({ key: model.key })
.unwrap()
.then((_) => {
toast({
id: 'MODEL_DELETED',
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
status: 'success',
});
})
.catch((error) => {
if (error) {
toast({
id: 'MODEL_DELETE_FAILED',
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
status: 'error',
});
}
});
dispatch(setSelectedModelKey(null));
}, [deleteModel, model, dispatch, t]);
return (
<Flex
sx={sx}
aria-selected={isSelected}
justifyContent="flex-start"
p={2}
borderRadius="base"
w="full"
alignItems="center"
alignItems="flex-start"
gap={2}
cursor="pointer"
onClick={handleSelectModel}
>
<Flex gap={2} w="full" h="full" minW={0}>
<ModelImage image_url={model.cover_image} />
<Flex gap={1} alignItems="flex-start" flexDir="column" w="full" minW={0}>
<Flex alignItems="flex-start" flexDir="column" w="full" minW={0}>
<Flex gap={2} w="full" alignItems="flex-start">
<Text fontWeight="semibold" noOfLines={1} wordBreak="break-all">
{model.name}
@@ -101,39 +93,15 @@ const ModelListItem = ({ model }: ModelListItemProps) => {
<Text variant="subtext" noOfLines={1}>
{model.description || 'No Description'}
</Text>
</Flex>
<Flex
h={MODEL_IMAGE_THUMBNAIL_SIZE}
flexDir="column"
alignItems="flex-end"
justifyContent="space-between"
gap={2}
>
<ModelBaseBadge base={model.base} />
<ModelFormatBadge format={model.format} />
<Flex gap={1} mt={1}>
<ModelBaseBadge base={model.base} />
<ModelFormatBadge format={model.format} />
</Flex>
</Flex>
</Flex>
<IconButton
onClick={onClickDeleteButton}
icon={<PiTrashSimpleBold size={16} />}
aria-label={t('modelManager.deleteConfig')}
colorScheme="error"
h={MODEL_IMAGE_THUMBNAIL_SIZE}
w={MODEL_IMAGE_THUMBNAIL_SIZE}
/>
<ConfirmationAlertDialog
isOpen={isOpen}
onClose={onClose}
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
useInert={false}
>
<Flex rowGap={4} flexDirection="column">
<Text fontWeight="bold">{t('modelManager.deleteMsg1')}</Text>
<Text>{t('modelManager.deleteMsg2')}</Text>
</Flex>
</ConfirmationAlertDialog>
<Flex mt={1}>
<ModelDeleteButton modelConfig={model} showLabel={false} />
</Flex>
</Flex>
);
};

View File

@@ -1,4 +1,4 @@
import { Flex, IconButton, Input, InputGroup, InputRightElement, Spacer } from '@invoke-ai/ui-library';
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectSearchTerm, setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { t } from 'i18next';
@@ -25,9 +25,7 @@ export const ModelListNavigation = memo(() => {
return (
<Flex gap={2} alignItems="center" justifyContent="space-between">
<ModelTypeFilter />
<Spacer />
<InputGroup maxW="400px">
<InputGroup>
<Input
placeholder={t('modelManager.search')}
value={searchTerm || ''}
@@ -47,6 +45,9 @@ export const ModelListNavigation = memo(() => {
</InputRightElement>
)}
</InputGroup>
<Flex shrink={0}>
<ModelTypeFilter />
</Flex>
</Flex>
);
});

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { StickyScrollable } from 'features/system/components/StickyScrollable';
import { memo } from 'react';
import type { AnyModelConfig } from 'services/api/types';
@@ -9,10 +10,23 @@ type ModelListWrapperProps = {
modelList: AnyModelConfig[];
};
const headingSx = {
bg: 'base.900',
pb: 3,
pl: 3,
} satisfies SystemStyleObject;
const contentSx = {
gap: 0,
p: 0,
bg: 'base.900',
borderRadius: '0',
} satisfies SystemStyleObject;
export const ModelListWrapper = memo((props: ModelListWrapperProps) => {
const { title, modelList } = props;
return (
<StickyScrollable title={title} contentSx={{ gap: 1, p: 2 }}>
<StickyScrollable title={title} contentSx={contentSx} headingSx={headingSx}>
{modelList.map((model) => (
<ModelListItem key={model.key} model={model} />
))}

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
@@ -6,13 +7,23 @@ import { memo } from 'react';
import { InstallModels } from './InstallModels';
import { Model } from './ModelPanel/Model';
const modelPaneSx: SystemStyleObject = {
layerStyle: 'first',
p: 4,
borderRadius: 'base',
w: {
base: '50%',
lg: '75%',
'2xl': '85%',
},
h: 'full',
minWidth: '300px',
overflow: 'auto',
};
export const ModelPane = memo(() => {
const selectedModelKey = useAppSelector(selectSelectedModelKey);
return (
<Box layerStyle="first" p={4} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
</Box>
);
return <Box sx={modelPaneSx}>{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}</Box>;
});
ModelPane.displayName = 'ModelPane';

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, IconButton, Image } from '@invoke-ai/ui-library';
import { dropzoneAccept } from 'common/hooks/useImageUploadButton';
import { typedMemo } from 'common/util/typedMemo';
@@ -8,6 +9,21 @@ import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiUploadBold } from 'react-icons/pi';
import { useDeleteModelImageMutation, useUpdateModelImageMutation } from 'services/api/endpoints/models';
const sharedSx: SystemStyleObject = {
w: 108,
h: 108,
fontSize: 36,
borderRadius: 'base',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
bg: 'base.800',
borderWidth: '1px',
borderStyle: 'solid',
borderColor: 'base.700',
flexShrink: 0,
};
type Props = {
model_key: string | null;
model_image?: string | null;
@@ -86,10 +102,9 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
src={image}
objectFit="cover"
objectPosition="50% 50%"
height={108}
width={108}
minWidth={108}
borderRadius="base"
sx={sharedSx}
/>
<IconButton
position="absolute"
@@ -112,10 +127,9 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
variant="ghost"
aria-label={t('modelManager.uploadImage')}
tooltip={t('modelManager.uploadImage')}
w={108}
h={108}
fontSize={36}
icon={<PiUploadBold />}
sx={sharedSx}
isLoading={request.isLoading}
{...getRootProps()}
/>

View File

@@ -52,6 +52,7 @@ export const ModelConvertButton = memo(({ modelConfig }: ModelConvertProps) => {
return (
<>
<Button
variant="outline"
onClick={onOpen}
size="sm"
aria-label={t('modelManager.convertToDiffusers')}

View File

@@ -0,0 +1,95 @@
import { Button, ConfirmationAlertDialog, Flex, IconButton, Text, useDisclosure } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppDispatch } from 'app/store/storeHooks';
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { toast } from 'features/toast/toast';
import { memo, type MouseEvent, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
type Props = {
showLabel?: boolean;
modelConfig: AnyModelConfig;
};
export const ModelDeleteButton = memo(({ showLabel = true, modelConfig }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const log = logger('models');
const [deleteModel] = useDeleteModelsMutation();
const { isOpen, onOpen, onClose } = useDisclosure();
const onClickDeleteButton = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
onOpen();
},
[onOpen]
);
const handleModelDelete = useCallback(() => {
deleteModel({ key: modelConfig.key })
.unwrap()
.then(() => {
dispatch(setSelectedModelKey(null));
toast({
id: 'MODEL_DELETED',
title: `${t('modelManager.modelDeleted')}: ${modelConfig.name}`,
status: 'success',
});
})
.catch((error) => {
log.error('Error deleting model', error);
toast({
id: 'MODEL_DELETE_FAILED',
title: `${t('modelManager.modelDeleteFailed')}: ${modelConfig.name}`,
status: 'error',
});
});
}, [deleteModel, modelConfig.key, modelConfig.name, dispatch, t, log]);
return (
<>
{showLabel ? (
<Button
className="delete-button"
size="sm"
leftIcon={<PiTrashSimpleBold />}
colorScheme="error"
onClick={onClickDeleteButton}
flexShrink={0}
>
{t('modelManager.delete')}
</Button>
) : (
<IconButton
className="delete-button"
onClick={onClickDeleteButton}
icon={<PiTrashSimpleBold size={16} />}
aria-label={t('modelManager.deleteConfig')}
colorScheme="error"
/>
)}
<ConfirmationAlertDialog
isOpen={isOpen}
onClose={onClose}
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
useInert={false}
>
<Flex rowGap={4} flexDirection="column">
<Text fontWeight="bold">{t('modelManager.deleteMsg1')}</Text>
<Text>{t('modelManager.deleteMsg2')}</Text>
</Flex>
</ConfirmationAlertDialog>
</>
);
});
ModelDeleteButton.displayName = 'ModelDeleteButton';

View File

@@ -24,6 +24,7 @@ import type { AnyModelConfig } from 'services/api/types';
import BaseModelSelect from './Fields/BaseModelSelect';
import ModelVariantSelect from './Fields/ModelVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
import { ModelFooter } from './ModelFooter';
type Props = {
modelConfig: AnyModelConfig;
@@ -158,6 +159,7 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
</Flex>
</form>
</Flex>
<ModelFooter modelConfig={modelConfig} isEditing={true} />
</Flex>
);
});

View File

@@ -0,0 +1,66 @@
import { Flex, Heading, type SystemStyleObject } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
import { ModelConvertButton } from './ModelConvertButton';
import { ModelDeleteButton } from './ModelDeleteButton';
import { ModelEditButton } from './ModelEditButton';
const footerRowSx: SystemStyleObject = {
justifyContent: 'space-between',
alignItems: 'center',
gap: 3,
'&:not(:last-of-type)': {
borderBottomWidth: '1px',
borderBottomStyle: 'solid',
borderBottomColor: 'border',
},
p: 3,
};
type Props = {
modelConfig: AnyModelConfig;
isEditing: boolean;
};
export const ModelFooter = memo(({ modelConfig, isEditing }: Props) => {
const { t } = useTranslation();
const shouldShowConvertOption = !isEditing && modelConfig.format === 'checkpoint' && modelConfig.type === 'main';
return (
<Flex flexDirection="column" borderWidth="1px" borderRadius="base">
{shouldShowConvertOption && (
<Flex sx={footerRowSx}>
<Heading size="sm" color="base.100">
{t('modelManager.convertToDiffusers')}
</Heading>
<Flex py={1}>
<ModelConvertButton modelConfig={modelConfig} />
</Flex>
</Flex>
)}
{!isEditing && (
<Flex sx={footerRowSx}>
<Heading size="sm" color="base.100">
{t('modelManager.edit')}
</Heading>
<Flex py={1}>
<ModelEditButton />
</Flex>
</Flex>
)}
<Flex sx={footerRowSx}>
<Heading size="sm" color="error.200">
{t('modelManager.deleteModel')}
</Heading>
<Flex py={1}>
<ModelDeleteButton modelConfig={modelConfig} />
</Flex>
</Flex>
</Flex>
);
});
ModelFooter.displayName = 'ModelFooter';

View File

@@ -1,4 +1,4 @@
import { Box, Flex, SimpleGrid } from '@invoke-ai/ui-library';
import { Box, Divider, Flex, SimpleGrid } from '@invoke-ai/ui-library';
import { ControlAdapterModelDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/ControlAdapterModelDefaultSettings';
import { LoRAModelDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/LoRAModelDefaultSettings/LoRAModelDefaultSettings';
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
@@ -12,6 +12,7 @@ import type { AnyModelConfig } from 'services/api/types';
import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings';
import { ModelAttrView } from './ModelAttrView';
import { ModelFooter } from './ModelFooter';
import { RelatedModels } from './RelatedModels';
type Props = {
@@ -39,15 +40,16 @@ export const ModelView = memo(({ modelConfig }: Props) => {
}, [modelConfig.base, modelConfig.type]);
return (
<Flex flexDir="column" gap={4}>
<Flex flexDir="column" gap={4} h="full">
<ModelHeader modelConfig={modelConfig}>
{modelConfig.format === 'checkpoint' && modelConfig.type === 'main' && (
<ModelConvertButton modelConfig={modelConfig} />
)}
<ModelEditButton />
</ModelHeader>
<Flex flexDir="column" h="full" gap={4}>
<Box layerStyle="second" borderRadius="base" p={4}>
<Divider />
<Flex flexDir="column" gap={4}>
<Box>
<SimpleGrid columns={2} gap={4}>
<ModelAttrView label={t('modelManager.baseModel')} value={modelConfig.base} />
<ModelAttrView label={t('modelManager.modelType')} value={modelConfig.type} />
@@ -73,26 +75,33 @@ export const ModelView = memo(({ modelConfig }: Props) => {
</SimpleGrid>
</Box>
{withSettings && (
<Box layerStyle="second" borderRadius="base" p={4}>
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
<MainModelDefaultSettings modelConfig={modelConfig} />
)}
{(modelConfig.type === 'controlnet' ||
modelConfig.type === 't2i_adapter' ||
modelConfig.type === 'control_lora') && <ControlAdapterModelDefaultSettings modelConfig={modelConfig} />}
{modelConfig.type === 'lora' && (
<>
<LoRAModelDefaultSettings modelConfig={modelConfig} />
<TriggerPhrases modelConfig={modelConfig} />
</>
)}
{modelConfig.type === 'main' && <TriggerPhrases modelConfig={modelConfig} />}
</Box>
<>
<Divider />
<Box>
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
<MainModelDefaultSettings modelConfig={modelConfig} />
)}
{(modelConfig.type === 'controlnet' ||
modelConfig.type === 't2i_adapter' ||
modelConfig.type === 'control_lora') && (
<ControlAdapterModelDefaultSettings modelConfig={modelConfig} />
)}
{modelConfig.type === 'lora' && (
<>
<LoRAModelDefaultSettings modelConfig={modelConfig} />
<TriggerPhrases modelConfig={modelConfig} />
</>
)}
{modelConfig.type === 'main' && <TriggerPhrases modelConfig={modelConfig} />}
</Box>
</>
)}
<Box overflowY="auto" layerStyle="second" borderRadius="base" p={4}>
<Divider />
<Box overflowY="auto">
<RelatedModels modelConfig={modelConfig} />
</Box>
</Flex>
<ModelFooter modelConfig={modelConfig} isEditing={false} />
</Flex>
);
});

View File

@@ -74,6 +74,8 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
[addTriggerPhrase]
);
const hasTriggerPhrases = triggerPhrases.length > 0;
return (
<Flex flexDir="column" w="full" gap="5">
<form onSubmit={onTriggerPhraseAddFormSubmit}>
@@ -99,14 +101,16 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
</FormControl>
</form>
<Flex gap="4" flexWrap="wrap">
{triggerPhrases.map((phrase, index) => (
<Tag size="md" key={index} py={2} px={4} bg="base.700">
<TagLabel>{phrase}</TagLabel>
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
</Tag>
))}
</Flex>
{hasTriggerPhrases && (
<Flex gap="4" flexWrap="wrap">
{triggerPhrases.map((phrase, index) => (
<Tag size="md" key={index} py={2} px={4} bg="base.700">
<TagLabel>{phrase}</TagLabel>
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
</Tag>
))}
</Flex>
)}
</Flex>
);
});

View File

@@ -1,14 +1,10 @@
import { FloatFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInput';
import { FloatFieldInputAndSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInputAndSlider';
import { FloatFieldSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldSlider';
import ChatGPT4oModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ChatGPT4oModelFieldInputComponent';
import { FloatFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatFieldCollectionInputComponent';
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
import FluxKontextModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxKontextModelFieldInputComponent';
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
import Imagen4ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen4ModelFieldInputComponent';
import { IntegerFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerFieldCollectionInputComponent';
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
@@ -27,22 +23,8 @@ import {
isBoardFieldInputTemplate,
isBooleanFieldInputInstance,
isBooleanFieldInputTemplate,
isChatGPT4oModelFieldInputInstance,
isChatGPT4oModelFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isCLIPGEmbedModelFieldInputInstance,
isCLIPGEmbedModelFieldInputTemplate,
isCLIPLEmbedModelFieldInputInstance,
isCLIPLEmbedModelFieldInputTemplate,
isCogView4MainModelFieldInputInstance,
isCogView4MainModelFieldInputTemplate,
isColorFieldInputInstance,
isColorFieldInputTemplate,
isControlLoRAModelFieldInputInstance,
isControlLoRAModelFieldInputTemplate,
isControlNetModelFieldInputInstance,
isControlNetModelFieldInputTemplate,
isEnumFieldInputInstance,
isEnumFieldInputTemplate,
isFloatFieldCollectionInputInstance,
@@ -51,68 +33,28 @@ import {
isFloatFieldInputTemplate,
isFloatGeneratorFieldInputInstance,
isFloatGeneratorFieldInputTemplate,
isFluxKontextModelFieldInputInstance,
isFluxKontextModelFieldInputTemplate,
isFluxMainModelFieldInputInstance,
isFluxMainModelFieldInputTemplate,
isFluxReduxModelFieldInputInstance,
isFluxReduxModelFieldInputTemplate,
isFluxVAEModelFieldInputInstance,
isFluxVAEModelFieldInputTemplate,
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isImageGeneratorFieldInputInstance,
isImageGeneratorFieldInputTemplate,
isImagen3ModelFieldInputInstance,
isImagen3ModelFieldInputTemplate,
isImagen4ModelFieldInputInstance,
isImagen4ModelFieldInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerFieldInputInstance,
isIntegerFieldInputTemplate,
isIntegerGeneratorFieldInputInstance,
isIntegerGeneratorFieldInputTemplate,
isIPAdapterModelFieldInputInstance,
isIPAdapterModelFieldInputTemplate,
isLLaVAModelFieldInputInstance,
isLLaVAModelFieldInputTemplate,
isLoRAModelFieldInputInstance,
isLoRAModelFieldInputTemplate,
isMainModelFieldInputInstance,
isMainModelFieldInputTemplate,
isModelIdentifierFieldInputInstance,
isModelIdentifierFieldInputTemplate,
isRunwayModelFieldInputInstance,
isRunwayModelFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSD3MainModelFieldInputInstance,
isSD3MainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
isSDXLRefinerModelFieldInputTemplate,
isSigLipModelFieldInputInstance,
isSigLipModelFieldInputTemplate,
isSpandrelImageToImageModelFieldInputInstance,
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isStringGeneratorFieldInputInstance,
isStringGeneratorFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isT5EncoderModelFieldInputInstance,
isT5EncoderModelFieldInputTemplate,
isVAEModelFieldInputInstance,
isVAEModelFieldInputTemplate,
isVeo3ModelFieldInputInstance,
isVeo3ModelFieldInputTemplate,
} from 'features/nodes/types/field';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import { memo } from 'react';
@@ -121,33 +63,10 @@ import { assert } from 'tsafe';
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
import CogView4MainModelFieldInputComponent from './inputs/CogView4MainModelFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
import FluxReduxModelFieldInputComponent from './inputs/FluxReduxModelFieldInputComponent';
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
import LLaVAModelFieldInputComponent from './inputs/LLaVAModelFieldInputComponent';
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import RunwayModelFieldInputComponent from './inputs/RunwayModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SigLipModelFieldInputComponent from './inputs/SigLipModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
import Veo3ModelFieldInputComponent from './inputs/Veo3ModelFieldInputComponent';
type Props = {
nodeId: string;
@@ -287,13 +206,6 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <BoardFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isMainModelFieldInputTemplate(template)) {
if (!isMainModelFieldInputInstance(field)) {
return null;
}
return <MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isModelIdentifierFieldInputTemplate(template)) {
if (!isModelIdentifierFieldInputInstance(field)) {
return null;
@@ -301,159 +213,6 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSDXLRefinerModelFieldInputTemplate(template)) {
if (!isSDXLRefinerModelFieldInputInstance(field)) {
return null;
}
return <RefinerModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isVAEModelFieldInputTemplate(template)) {
if (!isVAEModelFieldInputInstance(field)) {
return null;
}
return <VAEModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isT5EncoderModelFieldInputTemplate(template)) {
if (!isT5EncoderModelFieldInputInstance(field)) {
return null;
}
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isCLIPEmbedModelFieldInputTemplate(template)) {
if (!isCLIPEmbedModelFieldInputInstance(field)) {
return null;
}
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isCLIPLEmbedModelFieldInputTemplate(template)) {
if (!isCLIPLEmbedModelFieldInputInstance(field)) {
return null;
}
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isCLIPGEmbedModelFieldInputTemplate(template)) {
if (!isCLIPGEmbedModelFieldInputInstance(field)) {
return null;
}
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isControlLoRAModelFieldInputTemplate(template)) {
if (!isControlLoRAModelFieldInputInstance(field)) {
return null;
}
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isLLaVAModelFieldInputTemplate(template)) {
if (!isLLaVAModelFieldInputInstance(field)) {
return null;
}
return <LLaVAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isFluxVAEModelFieldInputTemplate(template)) {
if (!isFluxVAEModelFieldInputInstance(field)) {
return null;
}
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isLoRAModelFieldInputTemplate(template)) {
if (!isLoRAModelFieldInputInstance(field)) {
return null;
}
return <LoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isControlNetModelFieldInputTemplate(template)) {
if (!isControlNetModelFieldInputInstance(field)) {
return null;
}
return <ControlNetModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isIPAdapterModelFieldInputTemplate(template)) {
if (!isIPAdapterModelFieldInputInstance(field)) {
return null;
}
return <IPAdapterModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isT2IAdapterModelFieldInputTemplate(template)) {
if (!isT2IAdapterModelFieldInputInstance(field)) {
return null;
}
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSpandrelImageToImageModelFieldInputTemplate(template)) {
if (!isSpandrelImageToImageModelFieldInputInstance(field)) {
return null;
}
return <SpandrelImageToImageModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSigLipModelFieldInputTemplate(template)) {
if (!isSigLipModelFieldInputInstance(field)) {
return null;
}
return <SigLipModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isFluxReduxModelFieldInputTemplate(template)) {
if (!isFluxReduxModelFieldInputInstance(field)) {
return null;
}
return <FluxReduxModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isImagen3ModelFieldInputTemplate(template)) {
if (!isImagen3ModelFieldInputInstance(field)) {
return null;
}
return <Imagen3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isImagen4ModelFieldInputTemplate(template)) {
if (!isImagen4ModelFieldInputInstance(field)) {
return null;
}
return <Imagen4ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isFluxKontextModelFieldInputTemplate(template)) {
if (!isFluxKontextModelFieldInputInstance(field)) {
return null;
}
return <FluxKontextModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isChatGPT4oModelFieldInputTemplate(template)) {
if (!isChatGPT4oModelFieldInputInstance(field)) {
return null;
}
return <ChatGPT4oModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isVeo3ModelFieldInputTemplate(template)) {
if (!isVeo3ModelFieldInputInstance(field)) {
return null;
}
return <Veo3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isRunwayModelFieldInputTemplate(template)) {
if (!isRunwayModelFieldInputInstance(field)) {
return null;
}
return <RunwayModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isColorFieldInputTemplate(template)) {
if (!isColorFieldInputInstance(field)) {
return null;
@@ -461,34 +220,6 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <ColorFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isFluxMainModelFieldInputTemplate(template)) {
if (!isFluxMainModelFieldInputInstance(field)) {
return null;
}
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSD3MainModelFieldInputTemplate(template)) {
if (!isSD3MainModelFieldInputInstance(field)) {
return null;
}
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isCogView4MainModelFieldInputTemplate(template)) {
if (!isCogView4MainModelFieldInputInstance(field)) {
return null;
}
return <CogView4MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSDXLMainModelFieldInputTemplate(template)) {
if (!isSDXLMainModelFieldInputInstance(field)) {
return null;
}
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSchedulerFieldInputTemplate(template)) {
if (!isSchedulerFieldInputInstance(field)) {
return null;

View File

@@ -1,44 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate>;
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const onChange = useCallback(
(value: CLIPEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(CLIPEmbedModelFieldInputComponent);

View File

@@ -1,45 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPGEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPGEmbedModelConfig, isCLIPGEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate>;
const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const onChange = useCallback(
(value: CLIPGEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPGEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config))}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(CLIPGEmbedModelFieldInputComponent);

View File

@@ -1,45 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPLEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPLEmbedModelConfig, isCLIPLEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate>;
const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const onChange = useCallback(
(value: CLIPLEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPLEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config))}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(CLIPLEmbedModelFieldInputComponent);

View File

@@ -1,46 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldChatGPT4oModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useChatGPT4oModels } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const ChatGPT4oModelFieldInputComponent = (
props: FieldComponentProps<ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useChatGPT4oModels();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldChatGPT4oModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(ChatGPT4oModelFieldInputComponent);

View File

@@ -1,63 +0,0 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type {
CogView4MainModelFieldInputInstance,
CogView4MainModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useCogView4Models } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CogView4MainModelFieldInputInstance, CogView4MainModelFieldInputTemplate>;
const CogView4MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCogView4Models();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && props.fieldTemplate.required}
>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(CogView4MainModelFieldInputComponent);

View File

@@ -1,48 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldControlLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
ControlLoRAModelFieldInputInstance,
ControlLoRAModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useControlLoRAModel } from 'services/api/hooks/modelsByType';
import type { ControlLoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<ControlLoRAModelFieldInputInstance, ControlLoRAModelFieldInputTemplate>;
const ControlLoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useControlLoRAModel();
const onChange = useCallback(
(value: ControlLoRAModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldControlLoRAModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(ControlLoRAModelFieldInputComponent);

View File

@@ -1,45 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useControlNetModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate>;
const ControlNetModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useControlNetModels();
const onChange = useCallback(
(value: ControlNetModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldControlNetModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(ControlNetModelFieldInputComponent);

View File

@@ -1,49 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldFluxKontextModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
FluxKontextModelFieldInputInstance,
FluxKontextModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxKontextModels } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const FluxKontextModelFieldInputComponent = (
props: FieldComponentProps<FluxKontextModelFieldInputInstance, FluxKontextModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxKontextModels();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldFluxKontextModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(FluxKontextModelFieldInputComponent);

View File

@@ -1,44 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate>;
const FluxMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxModels();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(FluxMainModelFieldInputComponent);

View File

@@ -1,46 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldFluxReduxModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxReduxModels } from 'services/api/hooks/modelsByType';
import type { FLUXReduxModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const FluxReduxModelFieldInputComponent = (
props: FieldComponentProps<FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxReduxModels();
const onChange = useCallback(
(value: FLUXReduxModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldFluxReduxModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(FluxReduxModelFieldInputComponent);

View File

@@ -1,44 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate>;
const FluxVAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const onChange = useCallback(
(value: VAEModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldFluxVAEModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(FluxVAEModelFieldInputComponent);

View File

@@ -1,45 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
import type { IPAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const IPAdapterModelFieldInputComponent = (
props: FieldComponentProps<IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useIPAdapterModels();
const onChange = useCallback(
(value: IPAdapterModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldIPAdapterModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(IPAdapterModelFieldInputComponent);

View File

@@ -1,46 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldImagen3ModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useImagen3Models } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const Imagen3ModelFieldInputComponent = (
props: FieldComponentProps<Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useImagen3Models();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldImagen3ModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(Imagen3ModelFieldInputComponent);

View File

@@ -1,46 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldImagen4ModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useImagen4Models } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const Imagen4ModelFieldInputComponent = (
props: FieldComponentProps<Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useImagen4Models();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldImagen4ModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(Imagen4ModelFieldInputComponent);

View File

@@ -1,44 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldLLaVAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useLLaVAModels } from 'services/api/hooks/modelsByType';
import type { LlavaOnevisionConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate>;
const LLaVAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useLLaVAModels();
const onChange = useCallback(
(value: LlavaOnevisionConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldLLaVAModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(LLaVAModelFieldInputComponent);

View File

@@ -1,44 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate>;
const LoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useLoRAModels();
const onChange = useCallback(
(value: LoRAModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldLoRAModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(LoRAModelFieldInputComponent);

View File

@@ -1,44 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<MainModelFieldInputInstance, MainModelFieldInputTemplate>;
const MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(MainModelFieldInputComponent);

View File

@@ -12,7 +12,7 @@ import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate>;
const ModelIdentifierFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { nodeId, field, fieldTemplate } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetModelConfigsQuery();
const onChange = useCallback(
@@ -36,8 +36,31 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(data);
}, [data]);
if (!fieldTemplate.ui_model_base && !fieldTemplate.ui_model_type) {
return modelConfigsAdapterSelectors.selectAll(data);
}
return modelConfigsAdapterSelectors.selectAll(data).filter((config) => {
if (fieldTemplate.ui_model_base && !fieldTemplate.ui_model_base.includes(config.base)) {
return false;
}
if (fieldTemplate.ui_model_type && !fieldTemplate.ui_model_type.includes(config.type)) {
return false;
}
if (
fieldTemplate.ui_model_variant &&
'variant' in config &&
config.variant &&
!fieldTemplate.ui_model_variant.includes(config.variant)
) {
return false;
}
if (fieldTemplate.ui_model_format && !fieldTemplate.ui_model_format.includes(config.format)) {
return false;
}
return true;
});
}, [data, fieldTemplate]);
return (
<ModelFieldCombobox

View File

@@ -1,47 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
SDXLRefinerModelFieldInputInstance,
SDXLRefinerModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useRefinerModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<SDXLRefinerModelFieldInputInstance, SDXLRefinerModelFieldInputTemplate>;
const RefinerModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useRefinerModels();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldRefinerModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(RefinerModelFieldInputComponent);

View File

@@ -1,46 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldRunwayModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { RunwayModelFieldInputInstance, RunwayModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useRunwayModels } from 'services/api/hooks/modelsByType';
import type { VideoApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const RunwayModelFieldInputComponent = (
props: FieldComponentProps<RunwayModelFieldInputInstance, RunwayModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useRunwayModels();
const onChange = useCallback(
(value: VideoApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldRunwayModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(RunwayModelFieldInputComponent);

View File

@@ -1,44 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSD3Models } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
const SD3MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSD3Models();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(SD3MainModelFieldInputComponent);

View File

@@ -1,44 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSDXLModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate>;
const SDXLMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSDXLModels();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(SDXLMainModelFieldInputComponent);

View File

@@ -1,46 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldSigLipModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSigLipModels } from 'services/api/hooks/modelsByType';
import type { SigLipModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const SigLipModelFieldInputComponent = (
props: FieldComponentProps<SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSigLipModels();
const onChange = useCallback(
(value: SigLipModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldSigLipModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(SigLipModelFieldInputComponent);

View File

@@ -1,49 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldSpandrelImageToImageModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
SpandrelImageToImageModelFieldInputInstance,
SpandrelImageToImageModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const SpandrelImageToImageModelFieldInputComponent = (
props: FieldComponentProps<SpandrelImageToImageModelFieldInputInstance, SpandrelImageToImageModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
const onChange = useCallback(
(value: SpandrelImageToImageModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldSpandrelImageToImageModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(SpandrelImageToImageModelFieldInputComponent);

View File

@@ -1,46 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
import type { T2IAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const T2IAdapterModelFieldInputComponent = (
props: FieldComponentProps<T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT2IAdapterModels();
const onChange = useCallback(
(value: T2IAdapterModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldT2IAdapterModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(T2IAdapterModelFieldInputComponent);

View File

@@ -1,43 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldT5EncoderValueChanged } from 'features/nodes/store/nodesSlice';
import type { T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate>;
const T5EncoderModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const onChange = useCallback(
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldT5EncoderValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(T5EncoderModelFieldInputComponent);

View File

@@ -1,44 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<VAEModelFieldInputInstance, VAEModelFieldInputTemplate>;
const VAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useVAEModels();
const onChange = useCallback(
(value: VAEModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldVaeModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(VAEModelFieldInputComponent);

Some files were not shown because too many files have changed in this diff Show More