mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 01:58:14 -05:00
Compare commits
73 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3707c3b034 | ||
|
|
5885db4ab5 | ||
|
|
36ed9b750d | ||
|
|
3cec06f86e | ||
|
|
28b5f7a1c5 | ||
|
|
22cbb23ae0 | ||
|
|
4d585e3eec | ||
|
|
006b4356bb | ||
|
|
da947866f2 | ||
|
|
84a2cc6fc9 | ||
|
|
b50534bb49 | ||
|
|
c305e79fee | ||
|
|
c32949d113 | ||
|
|
87a98902da | ||
|
|
2857a446c9 | ||
|
|
035d9432bd | ||
|
|
bdeb9fb1cf | ||
|
|
dadff57061 | ||
|
|
480857ae4e | ||
|
|
eaf0624004 | ||
|
|
58bca1b9f4 | ||
|
|
54aa6908fa | ||
|
|
e6d9daca96 | ||
|
|
6e5a529cb7 | ||
|
|
8c742a6e38 | ||
|
|
693373f1c1 | ||
|
|
4809080fd9 | ||
|
|
efcb1bea7f | ||
|
|
e0d7a401f3 | ||
|
|
aac979e9a4 | ||
|
|
3b0d7f076d | ||
|
|
e1acbcdbd5 | ||
|
|
7d9b81550b | ||
|
|
6a447dd1fe | ||
|
|
c2dc63ddbc | ||
|
|
1bc689d531 | ||
|
|
4829975827 | ||
|
|
49da4e00c3 | ||
|
|
89dfe5e729 | ||
|
|
6816d366df | ||
|
|
9d3d2a36c9 | ||
|
|
ed231044c8 | ||
|
|
b51a232794 | ||
|
|
4412143a6e | ||
|
|
de11cafdb3 | ||
|
|
4d9114aa7d | ||
|
|
67e2da1ebf | ||
|
|
33ecc591c3 | ||
|
|
b57459a226 | ||
|
|
01282b1c90 | ||
|
|
3f302906dc | ||
|
|
81d56596fb | ||
|
|
b536b0df0c | ||
|
|
692af1d93d | ||
|
|
bb7ef77b50 | ||
|
|
1862548573 | ||
|
|
242c1b6350 | ||
|
|
fc6e4bb04e | ||
|
|
20841abca6 | ||
|
|
e8b69d99a4 | ||
|
|
d6eaff8237 | ||
|
|
068b095956 | ||
|
|
f795a47340 | ||
|
|
df47345eb0 | ||
|
|
def04095a4 | ||
|
|
28be8f0911 | ||
|
|
b50c44bac0 | ||
|
|
b4ce0e02fc | ||
|
|
d6442d9a34 | ||
|
|
4528bcafaf | ||
|
|
8b82b81ee2 | ||
|
|
757acdd49e | ||
|
|
94b7cc583a |
@@ -36,6 +36,9 @@ from pydantic_core import PydanticUndefined
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldKind,
|
||||
Input,
|
||||
InputFieldJSONSchemaExtra,
|
||||
UIType,
|
||||
migrate_model_ui_type,
|
||||
)
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@@ -256,7 +259,9 @@ class BaseInvocation(ABC, BaseModel):
|
||||
is_intermediate: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not this is an intermediate invocation.",
|
||||
json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute},
|
||||
json_schema_extra=InputFieldJSONSchemaExtra(
|
||||
input=Input.Direct, field_kind=FieldKind.NodeAttribute, ui_type=UIType._IsIntermediate
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
use_cache: bool = Field(
|
||||
default=True,
|
||||
@@ -445,6 +450,15 @@ with warnings.catch_warnings():
|
||||
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
|
||||
|
||||
|
||||
def is_enum_member(value: Any, enum_class: type[Enum]) -> bool:
|
||||
"""Checks if a value is a member of an enum class."""
|
||||
try:
|
||||
enum_class(value)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
||||
"""
|
||||
Validates the fields of an invocation or invocation output:
|
||||
@@ -456,51 +470,99 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
|
||||
"""
|
||||
for name, field in model_fields.items():
|
||||
if name in RESERVED_PYDANTIC_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
|
||||
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved by pydantic)")
|
||||
|
||||
if not field.annotation:
|
||||
raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')
|
||||
raise InvalidFieldError(f"{model_type}.{name}: Invalid field type (missing annotation)")
|
||||
|
||||
if not isinstance(field.json_schema_extra, dict):
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)'
|
||||
)
|
||||
raise InvalidFieldError(f"{model_type}.{name}: Invalid field definition (missing json_schema_extra dict)")
|
||||
|
||||
field_kind = field.json_schema_extra.get("field_kind", None)
|
||||
|
||||
# must have a field_kind
|
||||
if not isinstance(field_kind, FieldKind):
|
||||
if not is_enum_member(field_kind, FieldKind):
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
|
||||
f"{model_type}.{name}: Invalid field definition for (maybe it's not an InputField or OutputField?)"
|
||||
)
|
||||
|
||||
if field_kind is FieldKind.Input and (
|
||||
if field_kind == FieldKind.Input.value and (
|
||||
name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES
|
||||
):
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
|
||||
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved input field name)")
|
||||
|
||||
if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
|
||||
if field_kind == FieldKind.Output.value and name in RESERVED_OUTPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved output field name)")
|
||||
|
||||
if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
|
||||
)
|
||||
if field_kind == FieldKind.Internal.value and name not in RESERVED_INPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (internal field without reserved name)")
|
||||
|
||||
# node attribute fields *must* be in the reserved list
|
||||
if (
|
||||
field_kind is FieldKind.NodeAttribute
|
||||
field_kind == FieldKind.NodeAttribute.value
|
||||
and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES
|
||||
and name not in RESERVED_OUTPUT_FIELD_NAMES
|
||||
):
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)'
|
||||
f"{model_type}.{name}: Invalid field name (node attribute field without reserved name)"
|
||||
)
|
||||
|
||||
ui_type = field.json_schema_extra.get("ui_type", None)
|
||||
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
|
||||
logger.warning(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
|
||||
field.json_schema_extra.pop("ui_type")
|
||||
ui_model_base = field.json_schema_extra.get("ui_model_base", None)
|
||||
ui_model_type = field.json_schema_extra.get("ui_model_type", None)
|
||||
ui_model_variant = field.json_schema_extra.get("ui_model_variant", None)
|
||||
ui_model_format = field.json_schema_extra.get("ui_model_format", None)
|
||||
|
||||
if ui_type is not None:
|
||||
# There are 3 cases where we may need to take action:
|
||||
#
|
||||
# 1. The ui_type is a migratable, deprecated value. For example, ui_type=UIType.MainModel value is
|
||||
# deprecated and should be migrated to:
|
||||
# - ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]
|
||||
# - ui_model_type=[ModelType.Main]
|
||||
#
|
||||
# 2. ui_type was set in conjunction with any of the new ui_model_[base|type|variant|format] fields, which
|
||||
# is not allowed (they are mutually exclusive). In this case, we ignore ui_type and log a warning.
|
||||
#
|
||||
# 3. ui_type is a deprecated value that is not migratable. For example, ui_type=UIType.Image is deprecated;
|
||||
# Image fields are now automatically detected based on the field's type annotation. In this case, we
|
||||
# ignore ui_type and log a warning.
|
||||
#
|
||||
# The cases must be checked in this order to ensure proper handling.
|
||||
|
||||
# Easier to work with as an enum
|
||||
ui_type = UIType(ui_type)
|
||||
|
||||
# The enum member values are not always the same as their names - we want to log the name so the user can
|
||||
# easily review their code and see where the deprecated enum member is used.
|
||||
human_readable_name = f"UIType.{ui_type.name}"
|
||||
|
||||
# Case 1: migratable deprecated value
|
||||
did_migrate = migrate_model_ui_type(ui_type, field.json_schema_extra)
|
||||
|
||||
if did_migrate:
|
||||
logger.warning(
|
||||
f'{model_type}.{name}: Migrated deprecated "ui_type" "{human_readable_name}" to new ui_model_[base|type|variant|format] fields'
|
||||
)
|
||||
field.json_schema_extra.pop("ui_type")
|
||||
|
||||
# Case 2: mutually exclusive with new fields
|
||||
elif (
|
||||
ui_model_base is not None
|
||||
or ui_model_type is not None
|
||||
or ui_model_variant is not None
|
||||
or ui_model_format is not None
|
||||
):
|
||||
logger.warning(
|
||||
f'{model_type}.{name}: "ui_type" is mutually exclusive with "ui_model_[base|type|format|variant]", ignoring "ui_type"'
|
||||
)
|
||||
field.json_schema_extra.pop("ui_type")
|
||||
|
||||
# Case 3: deprecated value that is not migratable
|
||||
elif ui_type.startswith("DEPRECATED_"):
|
||||
logger.warning(f'{model_type}.{name}: Deprecated "ui_type" "{human_readable_name}", ignoring')
|
||||
field.json_schema_extra.pop("ui_type")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import (
|
||||
GlmEncoderField,
|
||||
ModelIdentifierField,
|
||||
@@ -14,6 +14,7 @@ from invokeai.app.invocations.model import (
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation_output("cogview4_model_loader_output")
|
||||
@@ -38,8 +39,9 @@ class CogView4ModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.cogview4_model,
|
||||
ui_type=UIType.CogView4MainModel,
|
||||
input=Input.Direct,
|
||||
ui_model_base=BaseModelType.CogView4,
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CogView4ModelLoaderOutput:
|
||||
|
||||
@@ -16,7 +16,6 @@ from invokeai.app.invocations.fields import (
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
@@ -28,6 +27,7 @@ from invokeai.app.util.controlnet_utils import (
|
||||
heuristic_resize_fast,
|
||||
)
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
@@ -63,13 +63,17 @@ class ControlOutput(BaseInvocationOutput):
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
|
||||
@invocation(
|
||||
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3"
|
||||
)
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
description=FieldDescriptions.controlnet_model,
|
||||
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, BaseModelType.StableDiffusionXL],
|
||||
ui_model_type=ModelType.ControlNet,
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
|
||||
@@ -7,6 +7,13 @@ from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.backend.image_util.segment_anything.shared import BoundingBox
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
@@ -39,47 +46,15 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
used, and the type will be ignored. They are included here for backwards compatibility.
|
||||
"""
|
||||
|
||||
# region Model Field Types
|
||||
MainModel = "MainModelField"
|
||||
CogView4MainModel = "CogView4MainModelField"
|
||||
FluxMainModel = "FluxMainModelField"
|
||||
SD3MainModel = "SD3MainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VAEModel = "VAEModelField"
|
||||
FluxVAEModel = "FluxVAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
T5EncoderModel = "T5EncoderModelField"
|
||||
CLIPEmbedModel = "CLIPEmbedModelField"
|
||||
CLIPLEmbedModel = "CLIPLEmbedModelField"
|
||||
CLIPGEmbedModel = "CLIPGEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
ControlLoRAModel = "ControlLoRAModelField"
|
||||
SigLipModel = "SigLipModelField"
|
||||
FluxReduxModel = "FluxReduxModelField"
|
||||
LlavaOnevisionModel = "LLaVAModelField"
|
||||
Imagen3Model = "Imagen3ModelField"
|
||||
Imagen4Model = "Imagen4ModelField"
|
||||
ChatGPT4oModel = "ChatGPT4oModelField"
|
||||
Gemini2_5Model = "Gemini2_5ModelField"
|
||||
FluxKontextModel = "FluxKontextModelField"
|
||||
Veo3Model = "Veo3ModelField"
|
||||
RunwayModel = "RunwayModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
Scheduler = "SchedulerField"
|
||||
Any = "AnyField"
|
||||
Video = "VideoField"
|
||||
# endregion
|
||||
|
||||
# region Internal Field Types
|
||||
_Collection = "CollectionField"
|
||||
_CollectionItem = "CollectionItemField"
|
||||
_IsIntermediate = "IsIntermediate"
|
||||
# endregion
|
||||
|
||||
# region DEPRECATED
|
||||
@@ -117,13 +92,44 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
CollectionItem = "DEPRECATED_CollectionItem"
|
||||
Enum = "DEPRECATED_Enum"
|
||||
WorkflowField = "DEPRECATED_WorkflowField"
|
||||
IsIntermediate = "DEPRECATED_IsIntermediate"
|
||||
BoardField = "DEPRECATED_BoardField"
|
||||
MetadataItem = "DEPRECATED_MetadataItem"
|
||||
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
|
||||
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
|
||||
MetadataDict = "DEPRECATED_MetadataDict"
|
||||
|
||||
# Deprecated Model Field Types - use ui_model_[base|type|variant|format] instead
|
||||
MainModel = "DEPRECATED_MainModelField"
|
||||
CogView4MainModel = "DEPRECATED_CogView4MainModelField"
|
||||
FluxMainModel = "DEPRECATED_FluxMainModelField"
|
||||
SD3MainModel = "DEPRECATED_SD3MainModelField"
|
||||
SDXLMainModel = "DEPRECATED_SDXLMainModelField"
|
||||
SDXLRefinerModel = "DEPRECATED_SDXLRefinerModelField"
|
||||
ONNXModel = "DEPRECATED_ONNXModelField"
|
||||
VAEModel = "DEPRECATED_VAEModelField"
|
||||
FluxVAEModel = "DEPRECATED_FluxVAEModelField"
|
||||
LoRAModel = "DEPRECATED_LoRAModelField"
|
||||
ControlNetModel = "DEPRECATED_ControlNetModelField"
|
||||
IPAdapterModel = "DEPRECATED_IPAdapterModelField"
|
||||
T2IAdapterModel = "DEPRECATED_T2IAdapterModelField"
|
||||
T5EncoderModel = "DEPRECATED_T5EncoderModelField"
|
||||
CLIPEmbedModel = "DEPRECATED_CLIPEmbedModelField"
|
||||
CLIPLEmbedModel = "DEPRECATED_CLIPLEmbedModelField"
|
||||
CLIPGEmbedModel = "DEPRECATED_CLIPGEmbedModelField"
|
||||
SpandrelImageToImageModel = "DEPRECATED_SpandrelImageToImageModelField"
|
||||
ControlLoRAModel = "DEPRECATED_ControlLoRAModelField"
|
||||
SigLipModel = "DEPRECATED_SigLipModelField"
|
||||
FluxReduxModel = "DEPRECATED_FluxReduxModelField"
|
||||
LlavaOnevisionModel = "DEPRECATED_LLaVAModelField"
|
||||
Imagen3Model = "DEPRECATED_Imagen3ModelField"
|
||||
Imagen4Model = "DEPRECATED_Imagen4ModelField"
|
||||
ChatGPT4oModel = "DEPRECATED_ChatGPT4oModelField"
|
||||
Gemini2_5Model = "DEPRECATED_Gemini2_5ModelField"
|
||||
FluxKontextModel = "DEPRECATED_FluxKontextModelField"
|
||||
Veo3Model = "DEPRECATED_Veo3ModelField"
|
||||
RunwayModel = "DEPRECATED_RunwayModelField"
|
||||
# endregion
|
||||
|
||||
|
||||
class UIComponent(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
@@ -409,10 +415,15 @@ class InputFieldJSONSchemaExtra(BaseModel):
|
||||
ui_component: Optional[UIComponent] = None
|
||||
ui_order: Optional[int] = None
|
||||
ui_choice_labels: Optional[dict[str, str]] = None
|
||||
ui_model_base: Optional[list[BaseModelType]] = None
|
||||
ui_model_type: Optional[list[ModelType]] = None
|
||||
ui_model_variant: Optional[list[ClipVariantType | ModelVariantType]] = None
|
||||
ui_model_format: Optional[list[ModelFormat]] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
use_enum_values=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -465,16 +476,121 @@ class OutputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
|
||||
field_kind: FieldKind
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
ui_hidden: bool = False
|
||||
ui_order: Optional[int] = None
|
||||
ui_type: Optional[UIType] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
use_enum_values=True,
|
||||
)
|
||||
|
||||
|
||||
def migrate_model_ui_type(ui_type: UIType | str, json_schema_extra: dict[str, Any]) -> bool:
|
||||
"""Migrate deprecated model-specifier ui_type values to new-style ui_model_[base|type|variant|format] in json_schema_extra."""
|
||||
if not isinstance(ui_type, UIType):
|
||||
ui_type = UIType(ui_type)
|
||||
|
||||
ui_model_type: list[ModelType] | None = None
|
||||
ui_model_base: list[BaseModelType] | None = None
|
||||
ui_model_format: list[ModelFormat] | None = None
|
||||
ui_model_variant: list[ClipVariantType | ModelVariantType] | None = None
|
||||
|
||||
match ui_type:
|
||||
case UIType.MainModel:
|
||||
ui_model_base = [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.CogView4MainModel:
|
||||
ui_model_base = [BaseModelType.CogView4]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.FluxMainModel:
|
||||
ui_model_base = [BaseModelType.Flux]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.SD3MainModel:
|
||||
ui_model_base = [BaseModelType.StableDiffusion3]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.SDXLMainModel:
|
||||
ui_model_base = [BaseModelType.StableDiffusionXL]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.SDXLRefinerModel:
|
||||
ui_model_base = [BaseModelType.StableDiffusionXLRefiner]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.VAEModel:
|
||||
ui_model_type = [ModelType.VAE]
|
||||
case UIType.FluxVAEModel:
|
||||
ui_model_base = [BaseModelType.Flux]
|
||||
ui_model_type = [ModelType.VAE]
|
||||
case UIType.LoRAModel:
|
||||
ui_model_type = [ModelType.LoRA]
|
||||
case UIType.ControlNetModel:
|
||||
ui_model_type = [ModelType.ControlNet]
|
||||
case UIType.IPAdapterModel:
|
||||
ui_model_type = [ModelType.IPAdapter]
|
||||
case UIType.T2IAdapterModel:
|
||||
ui_model_type = [ModelType.T2IAdapter]
|
||||
case UIType.T5EncoderModel:
|
||||
ui_model_type = [ModelType.T5Encoder]
|
||||
case UIType.CLIPEmbedModel:
|
||||
ui_model_type = [ModelType.CLIPEmbed]
|
||||
case UIType.CLIPLEmbedModel:
|
||||
ui_model_type = [ModelType.CLIPEmbed]
|
||||
ui_model_variant = [ClipVariantType.L]
|
||||
case UIType.CLIPGEmbedModel:
|
||||
ui_model_type = [ModelType.CLIPEmbed]
|
||||
ui_model_variant = [ClipVariantType.G]
|
||||
case UIType.SpandrelImageToImageModel:
|
||||
ui_model_type = [ModelType.SpandrelImageToImage]
|
||||
case UIType.ControlLoRAModel:
|
||||
ui_model_type = [ModelType.ControlLoRa]
|
||||
case UIType.SigLipModel:
|
||||
ui_model_type = [ModelType.SigLIP]
|
||||
case UIType.FluxReduxModel:
|
||||
ui_model_type = [ModelType.FluxRedux]
|
||||
case UIType.LlavaOnevisionModel:
|
||||
ui_model_type = [ModelType.LlavaOnevision]
|
||||
case UIType.Imagen3Model:
|
||||
ui_model_base = [BaseModelType.Imagen3]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.Imagen4Model:
|
||||
ui_model_base = [BaseModelType.Imagen4]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.ChatGPT4oModel:
|
||||
ui_model_base = [BaseModelType.ChatGPT4o]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.Gemini2_5Model:
|
||||
ui_model_base = [BaseModelType.Gemini2_5]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.FluxKontextModel:
|
||||
ui_model_base = [BaseModelType.FluxKontext]
|
||||
ui_model_type = [ModelType.Main]
|
||||
case UIType.Veo3Model:
|
||||
ui_model_base = [BaseModelType.Veo3]
|
||||
ui_model_type = [ModelType.Video]
|
||||
case UIType.RunwayModel:
|
||||
ui_model_base = [BaseModelType.Runway]
|
||||
ui_model_type = [ModelType.Video]
|
||||
case _:
|
||||
pass
|
||||
|
||||
did_migrate = False
|
||||
|
||||
if ui_model_type is not None:
|
||||
json_schema_extra["ui_model_type"] = [m.value for m in ui_model_type]
|
||||
did_migrate = True
|
||||
if ui_model_base is not None:
|
||||
json_schema_extra["ui_model_base"] = [m.value for m in ui_model_base]
|
||||
did_migrate = True
|
||||
if ui_model_format is not None:
|
||||
json_schema_extra["ui_model_format"] = [m.value for m in ui_model_format]
|
||||
did_migrate = True
|
||||
if ui_model_variant is not None:
|
||||
json_schema_extra["ui_model_variant"] = [m.value for m in ui_model_variant]
|
||||
did_migrate = True
|
||||
|
||||
return did_migrate
|
||||
|
||||
|
||||
def InputField(
|
||||
# copied from pydantic's Field
|
||||
# TODO: Can we support default_factory?
|
||||
@@ -501,35 +617,63 @@ def InputField(
|
||||
ui_hidden: Optional[bool] = None,
|
||||
ui_order: Optional[int] = None,
|
||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||
ui_model_base: Optional[BaseModelType | list[BaseModelType]] = None,
|
||||
ui_model_type: Optional[ModelType | list[ModelType]] = None,
|
||||
ui_model_variant: Optional[ClipVariantType | ModelVariantType | list[ClipVariantType | ModelVariantType]] = None,
|
||||
ui_model_format: Optional[ModelFormat | list[ModelFormat]] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field)
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||
`Input.Direct` means a value must be provided on instantiation. \
|
||||
`Input.Connection` means the value must be provided by a connection. \
|
||||
`Input.Any` means either will do.
|
||||
If the field is a `ModelIdentifierField`, use the `ui_model_[base|type|variant|format]` args to filter the model list
|
||||
in the Workflow Editor. Otherwise, use `ui_type` to provide extra type hints for the UI.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
Don't use both `ui_type` and `ui_model_[base|type|variant|format]` - if both are provided, a warning will be
|
||||
logged and `ui_type` will be ignored.
|
||||
|
||||
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||
For this case, you could provide `UIComponent.Textarea`.
|
||||
Args:
|
||||
input: The kind of input this field requires.
|
||||
- `Input.Direct` means a value must be provided on instantiation.
|
||||
- `Input.Connection` means the value must be provided by a connection.
|
||||
- `Input.Any` means either will do.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
ui_type: Optionally provides an extra type hint for the UI. In some situations, the field's type is not enough
|
||||
to infer the correct UI type. For example, Scheduler fields are enums, but we want to render a special scheduler
|
||||
dropdown in the UI. Use `UIType.Scheduler` to indicate this.
|
||||
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
|
||||
ui_component: Optionally specifies a specific component to use in the UI. The UI will always render a suitable
|
||||
component, but sometimes you want something different than the default. For example, a `string` field will
|
||||
default to a single-line input, but you may want a multi-line textarea instead. In this case, you could use
|
||||
`UIComponent.Textarea`.
|
||||
|
||||
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
|
||||
ui_hidden: Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
ui_order: Specifies the order in which this field should be rendered in the UI. If omitted, the field will be
|
||||
rendered after all fields with an explicit order, in the order they are defined in the Invocation class.
|
||||
|
||||
ui_model_base: Specifies the base model architectures to filter the model list by in the Workflow Editor. For
|
||||
example, `ui_model_base=BaseModelType.StableDiffusionXL` will show only SDXL architecture models. This arg is
|
||||
only valid if this Input field is annotated as a `ModelIdentifierField`.
|
||||
|
||||
ui_model_type: Specifies the model type(s) to filter the model list by in the Workflow Editor. For example,
|
||||
`ui_model_type=ModelType.VAE` will show only VAE models. This arg is only valid if this Input field is
|
||||
annotated as a `ModelIdentifierField`.
|
||||
|
||||
ui_model_variant: Specifies the model variant(s) to filter the model list by in the Workflow Editor. For example,
|
||||
`ui_model_variant=ModelVariantType.Inpainting` will show only inpainting models. This arg is only valid if this
|
||||
Input field is annotated as a `ModelIdentifierField`.
|
||||
|
||||
ui_model_format: Specifies the model format(s) to filter the model list by in the Workflow Editor. For example,
|
||||
`ui_model_format=ModelFormat.Diffusers` will show only models in the diffusers format. This arg is only valid
|
||||
if this Input field is annotated as a `ModelIdentifierField`.
|
||||
|
||||
ui_choice_labels: Specifies the labels to use for the choices in an enum field. If omitted, the enum values
|
||||
will be used. This arg is only valid if the field is annotated with as a `Literal`. For example,
|
||||
`Literal["choice1", "choice2", "choice3"]` with `ui_choice_labels={"choice1": "Choice 1", "choice2": "Choice 2",
|
||||
"choice3": "Choice 3"}` will render a dropdown with the labels "Choice 1", "Choice 2" and "Choice 3".
|
||||
"""
|
||||
|
||||
json_schema_extra_ = InputFieldJSONSchemaExtra(
|
||||
@@ -537,8 +681,6 @@ def InputField(
|
||||
field_kind=FieldKind.Input,
|
||||
)
|
||||
|
||||
if ui_type is not None:
|
||||
json_schema_extra_.ui_type = ui_type
|
||||
if ui_component is not None:
|
||||
json_schema_extra_.ui_component = ui_component
|
||||
if ui_hidden is not None:
|
||||
@@ -547,6 +689,28 @@ def InputField(
|
||||
json_schema_extra_.ui_order = ui_order
|
||||
if ui_choice_labels is not None:
|
||||
json_schema_extra_.ui_choice_labels = ui_choice_labels
|
||||
if ui_model_base is not None:
|
||||
if isinstance(ui_model_base, list):
|
||||
json_schema_extra_.ui_model_base = ui_model_base
|
||||
else:
|
||||
json_schema_extra_.ui_model_base = [ui_model_base]
|
||||
if ui_model_type is not None:
|
||||
if isinstance(ui_model_type, list):
|
||||
json_schema_extra_.ui_model_type = ui_model_type
|
||||
else:
|
||||
json_schema_extra_.ui_model_type = [ui_model_type]
|
||||
if ui_model_variant is not None:
|
||||
if isinstance(ui_model_variant, list):
|
||||
json_schema_extra_.ui_model_variant = ui_model_variant
|
||||
else:
|
||||
json_schema_extra_.ui_model_variant = [ui_model_variant]
|
||||
if ui_model_format is not None:
|
||||
if isinstance(ui_model_format, list):
|
||||
json_schema_extra_.ui_model_format = ui_model_format
|
||||
else:
|
||||
json_schema_extra_.ui_model_format = [ui_model_format]
|
||||
if ui_type is not None:
|
||||
json_schema_extra_.ui_type = ui_type
|
||||
|
||||
"""
|
||||
There is a conflict between the typing of invocation definitions and the typing of an invocation's
|
||||
@@ -648,20 +812,20 @@ def OutputField(
|
||||
"""
|
||||
Creates an output field for an invocation output.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization)
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
Args:
|
||||
ui_type: Optionally provides an extra type hint for the UI. In some situations, the field's type is not enough
|
||||
to infer the correct UI type. For example, Scheduler fields are enums, but we want to render a special scheduler
|
||||
dropdown in the UI. Use `UIType.Scheduler` to indicate this.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||
ui_hidden: Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
ui_order: Specifies the order in which this field should be rendered in the UI. If omitted, the field will be
|
||||
rendered after all fields with an explicit order, in the order they are defined in the Invocation class.
|
||||
"""
|
||||
|
||||
return Field(
|
||||
default=default,
|
||||
title=title,
|
||||
@@ -679,9 +843,9 @@ def OutputField(
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
json_schema_extra=OutputFieldJSONSchemaExtra(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
ui_type=ui_type,
|
||||
field_kind=FieldKind.Output,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
@@ -4,9 +4,10 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
|
||||
from invokeai.app.invocations.model import ControlLoRAField, ModelIdentifierField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation_output("flux_control_lora_loader_output")
|
||||
@@ -29,7 +30,10 @@ class FluxControlLoRALoaderInvocation(BaseInvocation):
|
||||
"""LoRA model and Image to use with FLUX transformer generation."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
|
||||
description=FieldDescriptions.control_lora_model,
|
||||
title="Control LoRA",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.ControlLoRa,
|
||||
)
|
||||
image: ImageField = InputField(description="The image to encode.")
|
||||
weight: float = InputField(description="The weight of the LoRA.", default=1.0)
|
||||
|
||||
@@ -6,11 +6,12 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
class FluxControlNetField(BaseModel):
|
||||
@@ -57,7 +58,9 @@ class FluxControlNetInvocation(BaseInvocation):
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
description=FieldDescriptions.controlnet_model,
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.ControlNet,
|
||||
)
|
||||
control_weight: float | list[float] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import InputField, UIType
|
||||
from invokeai.app.invocations.fields import InputField
|
||||
from invokeai.app.invocations.ip_adapter import (
|
||||
CLIP_VISION_MODEL_MAP,
|
||||
IPAdapterField,
|
||||
@@ -20,6 +20,7 @@ from invokeai.backend.model_manager.config import (
|
||||
IPAdapterCheckpointConfig,
|
||||
IPAdapterInvokeAIConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -36,7 +37,10 @@ class FluxIPAdapterInvocation(BaseInvocation):
|
||||
|
||||
image: ImageField = InputField(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: ModelIdentifierField = InputField(
|
||||
description="The IP-Adapter model.", title="IP-Adapter Model", ui_type=UIType.IPAdapterModel
|
||||
description="The IP-Adapter model.",
|
||||
title="IP-Adapter Model",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.IPAdapter,
|
||||
)
|
||||
# Currently, the only known ViT model used by FLUX IP-Adapters is ViT-L.
|
||||
clip_vision_model: Literal["ViT-L"] = InputField(description="CLIP Vision model to use.", default="ViT-L")
|
||||
|
||||
@@ -6,10 +6,10 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation_output("flux_lora_loader_output")
|
||||
@@ -36,7 +36,10 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
description=FieldDescriptions.lora_model,
|
||||
title="LoRA",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.LoRA,
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
transformer: TransformerField | None = InputField(
|
||||
|
||||
@@ -6,7 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.t5_model_identifier import (
|
||||
@@ -17,7 +17,7 @@ from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
|
||||
|
||||
@invocation_output("flux_model_loader_output")
|
||||
@@ -46,23 +46,30 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
ui_type=UIType.FluxMainModel,
|
||||
input=Input.Direct,
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Direct,
|
||||
title="T5 Encoder",
|
||||
ui_model_type=ModelType.T5Encoder,
|
||||
)
|
||||
|
||||
clip_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP Embed",
|
||||
ui_model_type=ModelType.CLIPEmbed,
|
||||
)
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
|
||||
description=FieldDescriptions.vae_model,
|
||||
title="VAE",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.VAE,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||
|
||||
@@ -18,7 +18,6 @@ from invokeai.app.invocations.fields import (
|
||||
InputField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
@@ -64,7 +63,8 @@ class FluxReduxInvocation(BaseInvocation):
|
||||
redux_model: ModelIdentifierField = InputField(
|
||||
description="The FLUX Redux model to use.",
|
||||
title="FLUX Redux Model",
|
||||
ui_type=UIType.FluxReduxModel,
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.FluxRedux,
|
||||
)
|
||||
downsampling_factor: int = InputField(
|
||||
ge=1,
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
@@ -85,7 +85,8 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
description="The IP-Adapter model.",
|
||||
title="IP-Adapter Model",
|
||||
ui_order=-1,
|
||||
ui_type=UIType.IPAdapterModel,
|
||||
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusionXL],
|
||||
ui_model_type=ModelType.IPAdapter,
|
||||
)
|
||||
clip_vision_model: Literal["ViT-H", "ViT-G", "ViT-L"] = InputField(
|
||||
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
|
||||
|
||||
@@ -6,11 +6,12 @@ from pydantic import field_validator
|
||||
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import StringOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline
|
||||
from invokeai.backend.model_manager.taxonomy import ModelType
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@@ -34,7 +35,7 @@ class LlavaOnevisionVllmInvocation(BaseInvocation):
|
||||
vllm_model: ModelIdentifierField = InputField(
|
||||
title="LLaVA Model Type",
|
||||
description=FieldDescriptions.vllm_model,
|
||||
ui_type=UIType.LlavaOnevisionModel,
|
||||
ui_model_type=ModelType.LlavaOnevision,
|
||||
)
|
||||
|
||||
@field_validator("images", mode="before")
|
||||
|
||||
@@ -53,7 +53,7 @@ from invokeai.app.invocations.primitives import (
|
||||
from invokeai.app.invocations.scheduler import SchedulerOutput
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField, T2IAdapterInvocation
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
from invokeai.version import __version__
|
||||
|
||||
@@ -473,7 +473,6 @@ class MetadataToModelOutput(BaseInvocationOutput):
|
||||
model: ModelIdentifierField = OutputField(
|
||||
description=FieldDescriptions.main_model,
|
||||
title="Model",
|
||||
ui_type=UIType.MainModel,
|
||||
)
|
||||
name: str = OutputField(description="Model Name", title="Name")
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
@@ -488,7 +487,6 @@ class MetadataToSDXLModelOutput(BaseInvocationOutput):
|
||||
model: ModelIdentifierField = OutputField(
|
||||
description=FieldDescriptions.main_model,
|
||||
title="Model",
|
||||
ui_type=UIType.SDXLMainModel,
|
||||
)
|
||||
name: str = OutputField(description="Model Name", title="Name")
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
@@ -519,8 +517,7 @@ class MetadataToModelInvocation(BaseInvocation, WithMetadata):
|
||||
input=Input.Direct,
|
||||
)
|
||||
default_value: ModelIdentifierField = InputField(
|
||||
description="The default model to use if not found in the metadata",
|
||||
ui_type=UIType.MainModel,
|
||||
description="The default model to use if not found in the metadata", ui_model_type=ModelType.Main
|
||||
)
|
||||
|
||||
_validate_custom_label = model_validator(mode="after")(validate_custom_label)
|
||||
@@ -575,7 +572,8 @@ class MetadataToSDXLModelInvocation(BaseInvocation, WithMetadata):
|
||||
)
|
||||
default_value: ModelIdentifierField = InputField(
|
||||
description="The default SDXL Model to use if not found in the metadata",
|
||||
ui_type=UIType.SDXLMainModel,
|
||||
ui_model_type=ModelType.Main,
|
||||
ui_model_base=BaseModelType.StableDiffusionXL,
|
||||
)
|
||||
|
||||
_validate_custom_label = model_validator(mode="after")(validate_custom_label)
|
||||
|
||||
@@ -9,7 +9,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@@ -145,7 +145,7 @@ class ModelIdentifierInvocation(BaseInvocation):
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model - SD1.5",
|
||||
title="Main Model - SD1.5, SD2",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.4",
|
||||
@@ -153,7 +153,11 @@ class ModelIdentifierInvocation(BaseInvocation):
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.main_model,
|
||||
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2],
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
@@ -187,7 +191,10 @@ class LoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
description=FieldDescriptions.lora_model,
|
||||
title="LoRA",
|
||||
ui_model_base=BaseModelType.StableDiffusion1,
|
||||
ui_model_type=ModelType.LoRA,
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
@@ -250,7 +257,9 @@ class LoRASelectorInvocation(BaseInvocation):
|
||||
"""Selects a LoRA model and weight."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
description=FieldDescriptions.lora_model,
|
||||
title="LoRA",
|
||||
ui_model_type=ModelType.LoRA,
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
|
||||
@@ -332,7 +341,10 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
description=FieldDescriptions.lora_model,
|
||||
title="LoRA",
|
||||
ui_model_base=BaseModelType.StableDiffusionXL,
|
||||
ui_model_type=ModelType.LoRA,
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
@@ -473,13 +485,26 @@ class SDXLLoRACollectionLoader(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"vae_loader", title="VAE Model - SD1.5, SDXL, SD3, FLUX", tags=["vae", "model"], category="model", version="1.0.4"
|
||||
"vae_loader",
|
||||
title="VAE Model - SD1.5, SD2, SDXL, SD3, FLUX",
|
||||
tags=["vae", "model"],
|
||||
category="model",
|
||||
version="1.0.4",
|
||||
)
|
||||
class VAELoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
|
||||
description=FieldDescriptions.vae_model,
|
||||
title="VAE",
|
||||
ui_model_base=[
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
BaseModelType.StableDiffusion3,
|
||||
BaseModelType.Flux,
|
||||
],
|
||||
ui_model_type=ModelType.VAE,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
|
||||
@@ -6,14 +6,14 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.t5_model_identifier import (
|
||||
preprocess_t5_encoder_model_identifier,
|
||||
preprocess_t5_tokenizer_model_identifier,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ClipVariantType, ModelType, SubModelType
|
||||
|
||||
|
||||
@invocation_output("sd3_model_loader_output")
|
||||
@@ -39,36 +39,43 @@ class Sd3ModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sd3_model,
|
||||
ui_type=UIType.SD3MainModel,
|
||||
input=Input.Direct,
|
||||
ui_model_base=BaseModelType.StableDiffusion3,
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
|
||||
t5_encoder_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
ui_type=UIType.T5EncoderModel,
|
||||
input=Input.Direct,
|
||||
title="T5 Encoder",
|
||||
default=None,
|
||||
ui_model_type=ModelType.T5Encoder,
|
||||
)
|
||||
|
||||
clip_l_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPLEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP L Encoder",
|
||||
default=None,
|
||||
ui_model_type=ModelType.CLIPEmbed,
|
||||
ui_model_variant=ClipVariantType.L,
|
||||
)
|
||||
|
||||
clip_g_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.clip_g_model,
|
||||
ui_type=UIType.CLIPGEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP G Encoder",
|
||||
default=None,
|
||||
ui_model_type=ModelType.CLIPEmbed,
|
||||
ui_model_variant=ClipVariantType.G,
|
||||
)
|
||||
|
||||
vae_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE", default=None
|
||||
description=FieldDescriptions.vae_model,
|
||||
title="VAE",
|
||||
default=None,
|
||||
ui_model_base=BaseModelType.StableDiffusion3,
|
||||
ui_model_type=ModelType.VAE,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, UNetField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
|
||||
|
||||
@invocation_output("sdxl_model_loader_output")
|
||||
@@ -29,7 +29,9 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
|
||||
description=FieldDescriptions.sdxl_main_model,
|
||||
ui_model_base=BaseModelType.StableDiffusionXL,
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
@@ -67,7 +69,9 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
|
||||
description=FieldDescriptions.sdxl_refiner_model,
|
||||
ui_model_base=BaseModelType.StableDiffusionXLRefiner,
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
@@ -19,6 +18,7 @@ from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import ModelType
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
|
||||
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||
@@ -33,7 +33,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_to_image_model: ModelIdentifierField = InputField(
|
||||
title="Image-to-Image Model",
|
||||
description=FieldDescriptions.spandrel_image_to_image_model,
|
||||
ui_type=UIType.SpandrelImageToImageModel,
|
||||
ui_model_type=ModelType.SpandrelImageToImage,
|
||||
)
|
||||
tile_size: int = InputField(
|
||||
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
|
||||
|
||||
@@ -8,11 +8,12 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
class T2IAdapterField(BaseModel):
|
||||
@@ -60,7 +61,8 @@ class T2IAdapterInvocation(BaseInvocation):
|
||||
description="The T2I-Adapter model.",
|
||||
title="T2I-Adapter Model",
|
||||
ui_order=-1,
|
||||
ui_type=UIType.T2IAdapterModel,
|
||||
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusionXL],
|
||||
ui_model_type=ModelType.T2IAdapter,
|
||||
)
|
||||
weight: Union[float, list[float]] = InputField(
|
||||
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
|
||||
|
||||
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Coroutine, Optional
|
||||
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelAllExceptCurrentResult,
|
||||
@@ -22,6 +23,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
|
||||
|
||||
@@ -135,6 +137,19 @@ class SessionQueueBase(ABC):
|
||||
"""Deletes all queue items except in-progress items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
limit: int,
|
||||
priority: int,
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
destination: Optional[str] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
"""Gets a page of session queue items. Do not remove."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_all_queue_items(
|
||||
self,
|
||||
|
||||
@@ -34,6 +34,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
@@ -588,6 +589,59 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
def list_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
limit: int,
|
||||
priority: int,
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
destination: Optional[str] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
with self._db.transaction() as cursor_:
|
||||
item_id = cursor
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(status)
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
# remove the extra item
|
||||
items.pop()
|
||||
has_more = True
|
||||
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
|
||||
|
||||
def list_all_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
|
||||
@@ -207,15 +207,24 @@ class IPAdapterPlusXL(IPAdapterPlus):
|
||||
|
||||
|
||||
def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
|
||||
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
|
||||
state_dict: IPAdapterStateDict = {
|
||||
"ip_adapter": {},
|
||||
"image_proj": {},
|
||||
"adapter_modules": {}, # added for noobai-mark-ipa
|
||||
"image_proj_model": {}, # added for noobai-mark-ipa
|
||||
}
|
||||
|
||||
if ip_adapter_ckpt_path.suffix == ".safetensors":
|
||||
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
|
||||
for key in model.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
|
||||
elif key.startswith("ip_adapter."):
|
||||
if key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
|
||||
elif key.startswith("image_proj_model."):
|
||||
state_dict["image_proj_model"][key.replace("image_proj_model.", "")] = model[key]
|
||||
elif key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
|
||||
elif key.startswith("adapter_modules."):
|
||||
state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = model[key]
|
||||
else:
|
||||
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
|
||||
else:
|
||||
|
||||
@@ -104,6 +104,7 @@
|
||||
"copy": "Copy",
|
||||
"copyError": "$t(gallery.copy) Error",
|
||||
"clipboard": "Clipboard",
|
||||
"crop": "Crop",
|
||||
"on": "On",
|
||||
"off": "Off",
|
||||
"or": "or",
|
||||
@@ -242,7 +243,10 @@
|
||||
"resultSubtitle": "Choose how to handle the expanded prompt:",
|
||||
"replace": "Replace",
|
||||
"insert": "Insert",
|
||||
"discard": "Discard"
|
||||
"discard": "Discard",
|
||||
"noPromptHistory": "No prompt history recorded.",
|
||||
"noMatchingPrompts": "No matching prompts in history.",
|
||||
"toSwitchBetweenPrompts": "to switch between prompts."
|
||||
},
|
||||
"queue": {
|
||||
"queue": "Queue",
|
||||
@@ -480,6 +484,14 @@
|
||||
"title": "Focus Prompt",
|
||||
"desc": "Move cursor focus to the positive prompt."
|
||||
},
|
||||
"promptHistoryPrev": {
|
||||
"title": "Previous Prompt in History",
|
||||
"desc": "When the prompt is focused, move to the previous (older) prompt in your history."
|
||||
},
|
||||
"promptHistoryNext": {
|
||||
"title": "Next Prompt in History",
|
||||
"desc": "When the prompt is focused, move to the next (newer) prompt in your history."
|
||||
},
|
||||
"toggleLeftPanel": {
|
||||
"title": "Toggle Left Panel",
|
||||
"desc": "Show or hide the left panel."
|
||||
@@ -1258,6 +1270,7 @@
|
||||
"infillColorValue": "Fill Color",
|
||||
"info": "Info",
|
||||
"startingFrameImage": "Start Frame",
|
||||
"startingFrameImageAspectRatioWarning": "Image aspect ratio does not match the video aspect ratio ({{videoAspectRatio}}). This could lead to unexpected cropping during video generation.",
|
||||
"invoke": {
|
||||
"addingImagesTo": "Adding images to",
|
||||
"modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your account settings to upgrade.",
|
||||
|
||||
@@ -131,7 +131,8 @@
|
||||
"notInstalled": "Non $t(common.installed)",
|
||||
"prevPage": "Pagina precedente",
|
||||
"nextPage": "Pagina successiva",
|
||||
"resetToDefaults": "Ripristina impostazioni predefinite"
|
||||
"resetToDefaults": "Ripristina impostazioni predefinite",
|
||||
"crop": "Ritaglia"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@@ -278,6 +279,14 @@
|
||||
"selectVideoTab": {
|
||||
"title": "Seleziona la scheda Video",
|
||||
"desc": "Seleziona la scheda Video."
|
||||
},
|
||||
"promptHistoryPrev": {
|
||||
"title": "Prompt precedente nella cronologia",
|
||||
"desc": "Quando il prompt è attivo, passa al prompt precedente (più vecchio) nella cronologia."
|
||||
},
|
||||
"promptHistoryNext": {
|
||||
"title": "Prossimo prompt nella cronologia",
|
||||
"desc": "Quando il prompt è attivo, passa al prompt successivo (più recente) nella cronologia."
|
||||
}
|
||||
},
|
||||
"hotkeys": "Tasti di scelta rapida",
|
||||
@@ -875,7 +884,8 @@
|
||||
"video": "Video",
|
||||
"resolution": "Risoluzione",
|
||||
"downloadImage": "Scarica l'immagine",
|
||||
"showOptionsPanel": "Mostra pannello laterale (O o T)"
|
||||
"showOptionsPanel": "Mostra pannello laterale (O o T)",
|
||||
"startingFrameImageAspectRatioWarning": "Le proporzioni dell'immagine non corrispondono alle proporzioni del video ({{videoAspectRatio}}). Ciò potrebbe causare ritagli imprevisti durante la generazione del video."
|
||||
},
|
||||
"settings": {
|
||||
"models": "Modelli",
|
||||
@@ -2095,7 +2105,10 @@
|
||||
"generateFromImage": "Genera prompt dall'immagine",
|
||||
"resultTitle": "Espansione del prompt completata",
|
||||
"resultSubtitle": "Scegli come gestire il prompt espanso:",
|
||||
"insert": "Inserisci"
|
||||
"insert": "Inserisci",
|
||||
"noPromptHistory": "Nessuna cronologia di prompt registrata.",
|
||||
"noMatchingPrompts": "Nessun prompt corrispondente nella cronologia.",
|
||||
"toSwitchBetweenPrompts": "per passare da un prompt all'altro."
|
||||
},
|
||||
"controlLayers": {
|
||||
"addLayer": "Aggiungi Livello",
|
||||
@@ -2791,7 +2804,8 @@
|
||||
"watchRecentReleaseVideos": "Guarda i video su questa versione",
|
||||
"items": [
|
||||
"Seleziona oggetto v2: selezione degli oggetti migliorata con input di punti e riquadri o prompt di testo.",
|
||||
"Regolazioni del livello raster: regola facilmente la luminosità, il contrasto, la saturazione, le curve e altro ancora del livello."
|
||||
"Regolazioni del livello raster: regola facilmente la luminosità, il contrasto, la saturazione, le curve e altro ancora del livello.",
|
||||
"Cronologia prompt: rivedi e richiama rapidamente i tuoi ultimi 100 prompt."
|
||||
],
|
||||
"watchUiUpdatesOverview": "Guarda la panoramica degli aggiornamenti dell'interfaccia utente"
|
||||
},
|
||||
|
||||
@@ -2,6 +2,7 @@ import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys';
|
||||
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
|
||||
import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { CropImageModal } from 'features/cropper/components/CropImageModal';
|
||||
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
|
||||
import { DeleteVideoModal } from 'features/deleteVideoModal/components/DeleteVideoModal';
|
||||
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
|
||||
@@ -58,6 +59,7 @@ export const GlobalModalIsolator = memo(() => {
|
||||
<CanvasPasteModal />
|
||||
</CanvasManagerProviderGate>
|
||||
<LoadWorkflowFromGraphModal />
|
||||
<CropImageModal />
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -4,7 +4,6 @@ import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { canvasReset } from 'features/controlLayers/store/actions';
|
||||
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
|
||||
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
|
||||
import { sentImageToCanvas } from 'features/gallery/store/actions';
|
||||
@@ -164,7 +163,6 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
case 'generation':
|
||||
// Go to the generate tab, open the launchpad
|
||||
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
|
||||
store.dispatch(paramsReset());
|
||||
break;
|
||||
case 'canvas':
|
||||
// Go to the canvas tab, open the launchpad
|
||||
|
||||
@@ -12,7 +12,13 @@ import {
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier, isFLUXReduxConfig, isIPAdapterConfig } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
getEntityIdentifier,
|
||||
isFLUXReduxConfig,
|
||||
isIPAdapterConfig,
|
||||
isRegionalGuidanceFLUXReduxConfig,
|
||||
isRegionalGuidanceIPAdapterConfig,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import {
|
||||
@@ -252,7 +258,7 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
|
||||
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
|
||||
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
|
||||
if (!isIPAdapterConfig(config)) {
|
||||
if (!isRegionalGuidanceIPAdapterConfig(config)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -295,7 +301,7 @@ const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
|
||||
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
|
||||
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
|
||||
if (!isFLUXReduxConfig(config)) {
|
||||
if (!isRegionalGuidanceFLUXReduxConfig(config)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ export const RasterLayerAdjustmentsPanel = memo(() => {
|
||||
}
|
||||
const rect = adapter.transformer.getRelativeRect();
|
||||
try {
|
||||
await adapter.renderer.rasterize({ rect, replaceObjects: true });
|
||||
await adapter.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1 } });
|
||||
// Clear adjustments after baking
|
||||
dispatch(rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: null }));
|
||||
} catch {
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { objectEquals } from '@observ33r/object-equals';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
|
||||
import { bboxSizeOptimized, bboxSizeRecalled } from 'features/controlLayers/store/canvasSlice';
|
||||
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { sizeOptimized, sizeRecalled } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { ImageWithDims } from 'features/controlLayers/store/types';
|
||||
import type { CroppableImageWithDims } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToCroppableImage, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import { Editor } from 'features/cropper/lib/editor';
|
||||
import { cropImageModalApi } from 'features/cropper/store';
|
||||
import type { setGlobalReferenceImageDndTarget, setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
@@ -14,14 +18,14 @@ import { DndImageIcon } from 'features/dnd/DndImageIcon';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { PiArrowCounterClockwiseBold, PiCropBold, PiRulerBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery, useUploadImageMutation } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
type Props<T extends typeof setGlobalReferenceImageDndTarget | typeof setRegionalGuidanceReferenceImageDndTarget> = {
|
||||
image: ImageWithDims | null;
|
||||
onChangeImage: (imageDTO: ImageDTO | null) => void;
|
||||
image: CroppableImageWithDims | null;
|
||||
onChangeImage: (croppableImage: CroppableImageWithDims | null) => void;
|
||||
dndTarget: T;
|
||||
dndTargetData: ReturnType<T['getData']>;
|
||||
};
|
||||
@@ -38,20 +42,28 @@ export const RefImageImage = memo(
|
||||
const isConnected = useStore($isConnected);
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const isStaging = useCanvasIsStaging();
|
||||
const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
|
||||
const imageWithDims = image?.crop?.image ?? image?.original.image ?? null;
|
||||
const croppedImageDTOReq = useGetImageDTOQuery(image?.crop?.image?.image_name ?? skipToken);
|
||||
const originalImageDTOReq = useGetImageDTOQuery(image?.original.image.image_name ?? skipToken);
|
||||
const [uploadImage] = useUploadImageMutation();
|
||||
|
||||
const originalImageDTO = originalImageDTOReq.currentData;
|
||||
const croppedImageDTO = croppedImageDTOReq.currentData;
|
||||
const imageDTO = croppedImageDTO ?? originalImageDTO;
|
||||
|
||||
const handleResetControlImage = useCallback(() => {
|
||||
onChangeImage(null);
|
||||
}, [onChangeImage]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isConnected && isError) {
|
||||
if ((isConnected && croppedImageDTOReq.isError) || originalImageDTOReq.isError) {
|
||||
handleResetControlImage();
|
||||
}
|
||||
}, [handleResetControlImage, isError, isConnected]);
|
||||
}, [handleResetControlImage, isConnected, croppedImageDTOReq.isError, originalImageDTOReq.isError]);
|
||||
|
||||
const onUpload = useCallback(
|
||||
(imageDTO: ImageDTO) => {
|
||||
onChangeImage(imageDTO);
|
||||
onChangeImage(imageDTOToCroppableImage(imageDTO));
|
||||
},
|
||||
[onChangeImage]
|
||||
);
|
||||
@@ -70,13 +82,67 @@ export const RefImageImage = memo(
|
||||
}
|
||||
}, [imageDTO, isStaging, store, tab]);
|
||||
|
||||
const edit = useCallback(() => {
|
||||
if (!originalImageDTO) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We will create a new editor instance each time the user wants to edit
|
||||
const editor = new Editor();
|
||||
|
||||
// When the user applies the crop, we will upload the cropped image and store the applied crop box so if the user
|
||||
// re-opens the editor they see the same crop
|
||||
const onApplyCrop = async () => {
|
||||
const box = editor.getCropBox();
|
||||
if (objectEquals(box, image?.crop?.box)) {
|
||||
// If the box hasn't changed, don't do anything
|
||||
return;
|
||||
}
|
||||
if (!box || objectEquals(box, { x: 0, y: 0, width: originalImageDTO.width, height: originalImageDTO.height })) {
|
||||
// There is a crop applied but it is the whole iamge - revert to original image
|
||||
onChangeImage(imageDTOToCroppableImage(originalImageDTO));
|
||||
return;
|
||||
}
|
||||
const blob = await editor.exportImage('blob');
|
||||
const file = new File([blob], 'image.png', { type: 'image/png' });
|
||||
|
||||
const newCroppedImageDTO = await uploadImage({
|
||||
file,
|
||||
is_intermediate: true,
|
||||
image_category: 'user',
|
||||
}).unwrap();
|
||||
|
||||
onChangeImage(
|
||||
imageDTOToCroppableImage(originalImageDTO, {
|
||||
image: imageDTOToImageWithDims(newCroppedImageDTO),
|
||||
box,
|
||||
ratio: editor.getCropAspectRatio(),
|
||||
})
|
||||
);
|
||||
};
|
||||
|
||||
const onReady = async () => {
|
||||
const initial = image?.crop ? { cropBox: image.crop.box, aspectRatio: image.crop.ratio } : undefined;
|
||||
// Load the image into the editor and open the modal once it's ready
|
||||
await editor.loadImage(originalImageDTO.image_url, initial);
|
||||
};
|
||||
|
||||
cropImageModalApi.open({ editor, onApplyCrop, onReady });
|
||||
}, [image?.crop, onChangeImage, originalImageDTO, uploadImage]);
|
||||
|
||||
return (
|
||||
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
|
||||
<Flex
|
||||
position="relative"
|
||||
w="full"
|
||||
h="full"
|
||||
alignItems="center"
|
||||
data-error={!imageDTO && !imageWithDims?.image_name}
|
||||
>
|
||||
{!imageDTO && (
|
||||
<UploadImageIconButton
|
||||
w="full"
|
||||
h="full"
|
||||
isError={!imageDTO && !image?.image_name}
|
||||
isError={!imageDTO && !imageWithDims?.image_name}
|
||||
onUpload={onUpload}
|
||||
fontSize={36}
|
||||
/>
|
||||
@@ -99,6 +165,15 @@ export const RefImageImage = memo(
|
||||
isDisabled={!imageDTO || (tab === 'canvas' && isStaging)}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
<Flex position="absolute" flexDir="column" top={2} insetInlineStart={2} gap={1}>
|
||||
<DndImageIcon
|
||||
onClick={edit}
|
||||
icon={<PiCropBold size={16} />}
|
||||
tooltip={t('common.crop')}
|
||||
isDisabled={!imageDTO}
|
||||
/>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
<DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />
|
||||
|
||||
@@ -13,7 +13,7 @@ import {
|
||||
selectRefImageEntityIds,
|
||||
selectSelectedRefEntityId,
|
||||
} from 'features/controlLayers/store/refImagesSlice';
|
||||
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
|
||||
import { addGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
@@ -92,7 +92,7 @@ const AddRefImageDropTargetAndButton = memo(() => {
|
||||
({
|
||||
onUpload: (imageDTO: ImageDTO) => {
|
||||
const config = getDefaultRefImageConfig(getState);
|
||||
config.image = imageDTOToImageWithDims(imageDTO);
|
||||
config.image = imageDTOToCroppableImage(imageDTO);
|
||||
dispatch(refImageAdded({ overrides: { config } }));
|
||||
},
|
||||
allowMultiple: false,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, Icon, IconButton, Image, Skeleton, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { round } from 'es-toolkit/compat';
|
||||
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
|
||||
@@ -15,7 +14,7 @@ import { isIPAdapterConfig } from 'features/controlLayers/store/types';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { useImageDTOFromCroppableImage } from 'services/api/endpoints/images';
|
||||
|
||||
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
|
||||
|
||||
@@ -72,7 +71,8 @@ export const RefImagePreview = memo(() => {
|
||||
const selectedEntityId = useAppSelector(selectSelectedRefEntityId);
|
||||
const isPanelOpen = useAppSelector(selectIsRefImagePanelOpen);
|
||||
const [showWeightDisplay, setShowWeightDisplay] = useState(false);
|
||||
const { data: imageDTO } = useGetImageDTOQuery(entity.config.image?.image_name ?? skipToken);
|
||||
|
||||
const imageDTO = useImageDTOFromCroppableImage(entity.config.image);
|
||||
|
||||
const sx = useMemo(() => {
|
||||
if (!isIPAdapterConfig(entity.config)) {
|
||||
@@ -145,7 +145,7 @@ export const RefImagePreview = memo(() => {
|
||||
overflow="hidden"
|
||||
>
|
||||
<Image
|
||||
src={imageDTO?.thumbnail_url}
|
||||
src={imageDTO?.image_url}
|
||||
objectFit="contain"
|
||||
aspectRatio="1/1"
|
||||
height={imageDTO?.height}
|
||||
|
||||
@@ -30,6 +30,7 @@ import {
|
||||
} from 'features/controlLayers/store/refImagesSlice';
|
||||
import type {
|
||||
CLIPVisionModelV2,
|
||||
CroppableImageWithDims,
|
||||
FLUXReduxImageInfluence as FLUXReduxImageInfluenceType,
|
||||
IPMethodV2,
|
||||
} from 'features/controlLayers/store/types';
|
||||
@@ -42,7 +43,6 @@ import type {
|
||||
ChatGPT4oModelConfig,
|
||||
FLUXKontextModelConfig,
|
||||
FLUXReduxModelConfig,
|
||||
ImageDTO,
|
||||
IPAdapterModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
@@ -104,15 +104,19 @@ const RefImageSettingsContent = memo(() => {
|
||||
);
|
||||
|
||||
const onChangeImage = useCallback(
|
||||
(imageDTO: ImageDTO | null) => {
|
||||
dispatch(refImageImageChanged({ id, imageDTO }));
|
||||
(croppableImage: CroppableImageWithDims | null) => {
|
||||
dispatch(refImageImageChanged({ id, croppableImage }));
|
||||
},
|
||||
[dispatch, id]
|
||||
);
|
||||
|
||||
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
|
||||
() => setGlobalReferenceImageDndTarget.getData({ id }, config.image?.image_name),
|
||||
[id, config.image?.image_name]
|
||||
() =>
|
||||
setGlobalReferenceImageDndTarget.getData(
|
||||
{ id },
|
||||
config.image?.crop?.image.image_name ?? config.image?.original.image.image_name
|
||||
),
|
||||
[id, config.image?.crop?.image.image_name, config.image?.original.image.image_name]
|
||||
);
|
||||
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
|
||||
@@ -6,7 +6,6 @@ import { FLUXReduxImageInfluence } from 'features/controlLayers/components/commo
|
||||
import { IPAdapterCLIPVisionModel } from 'features/controlLayers/components/common/IPAdapterCLIPVisionModel';
|
||||
import { Weight } from 'features/controlLayers/components/common/Weight';
|
||||
import { IPAdapterMethod } from 'features/controlLayers/components/RefImage/IPAdapterMethod';
|
||||
import { RefImageImage } from 'features/controlLayers/components/RefImage/RefImageImage';
|
||||
import { RegionalGuidanceIPAdapterSettingsEmptyState } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState';
|
||||
import { RegionalReferenceImageModel } from 'features/controlLayers/components/RegionalGuidance/RegionalReferenceImageModel';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
@@ -37,6 +36,8 @@ import { PiBoundingBoxBold, PiXBold } from 'react-icons/pi';
|
||||
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { RegionalGuidanceRefImageImage } from './RegionalGuidanceRefImageImage';
|
||||
|
||||
type Props = {
|
||||
referenceImageId: string;
|
||||
};
|
||||
@@ -114,7 +115,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
|
||||
{ entityIdentifier, referenceImageId },
|
||||
config.image?.image_name
|
||||
),
|
||||
[entityIdentifier, config.image?.image_name, referenceImageId]
|
||||
[entityIdentifier, config.image, referenceImageId]
|
||||
);
|
||||
|
||||
const pullBboxIntoIPAdapter = usePullBboxIntoRegionalGuidanceReferenceImage(entityIdentifier, referenceImageId);
|
||||
@@ -170,7 +171,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
|
||||
</Flex>
|
||||
)}
|
||||
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
|
||||
<RefImageImage
|
||||
<RegionalGuidanceRefImageImage
|
||||
image={config.image}
|
||||
onChangeImage={onChangeImage}
|
||||
dndTarget={setRegionalGuidanceReferenceImageDndTarget}
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
|
||||
import { bboxSizeOptimized, bboxSizeRecalled } from 'features/controlLayers/store/canvasSlice';
|
||||
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { sizeOptimized, sizeRecalled } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { ImageWithDims } from 'features/controlLayers/store/types';
|
||||
import type { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { DndImageIcon } from 'features/dnd/DndImageIcon';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
type Props = {
|
||||
image: ImageWithDims | null;
|
||||
onChangeImage: (imageDTO: ImageDTO | null) => void;
|
||||
dndTarget: typeof setRegionalGuidanceReferenceImageDndTarget;
|
||||
dndTargetData: ReturnType<(typeof setRegionalGuidanceReferenceImageDndTarget)['getData']>;
|
||||
};
|
||||
|
||||
export const RegionalGuidanceRefImageImage = memo(({ image, onChangeImage, dndTarget, dndTargetData }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const store = useAppStore();
|
||||
const isConnected = useStore($isConnected);
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const isStaging = useCanvasIsStaging();
|
||||
const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
|
||||
const handleResetControlImage = useCallback(() => {
|
||||
onChangeImage(null);
|
||||
}, [onChangeImage]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isConnected && isError) {
|
||||
handleResetControlImage();
|
||||
}
|
||||
}, [handleResetControlImage, isError, isConnected]);
|
||||
|
||||
const onUpload = useCallback(
|
||||
(imageDTO: ImageDTO) => {
|
||||
onChangeImage(imageDTO);
|
||||
},
|
||||
[onChangeImage]
|
||||
);
|
||||
|
||||
const recallSizeAndOptimize = useCallback(() => {
|
||||
if (!imageDTO || (tab === 'canvas' && isStaging)) {
|
||||
return;
|
||||
}
|
||||
const { width, height } = imageDTO;
|
||||
if (tab === 'canvas') {
|
||||
store.dispatch(bboxSizeRecalled({ width, height }));
|
||||
store.dispatch(bboxSizeOptimized());
|
||||
} else if (tab === 'generate') {
|
||||
store.dispatch(sizeRecalled({ width, height }));
|
||||
store.dispatch(sizeOptimized());
|
||||
}
|
||||
}, [imageDTO, isStaging, store, tab]);
|
||||
|
||||
return (
|
||||
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
|
||||
{!imageDTO && (
|
||||
<UploadImageIconButton
|
||||
w="full"
|
||||
h="full"
|
||||
isError={!imageDTO && !image?.image_name}
|
||||
onUpload={onUpload}
|
||||
fontSize={36}
|
||||
/>
|
||||
)}
|
||||
{imageDTO && (
|
||||
<>
|
||||
<DndImage imageDTO={imageDTO} borderRadius="base" borderWidth={1} borderStyle="solid" w="full" />
|
||||
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
|
||||
<DndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||
tooltip={t('common.reset')}
|
||||
/>
|
||||
</Flex>
|
||||
<Flex position="absolute" flexDir="column" bottom={2} insetInlineEnd={2} gap={1}>
|
||||
<DndImageIcon
|
||||
onClick={recallSizeAndOptimize}
|
||||
icon={<PiRulerBold size={16} />}
|
||||
tooltip={t('parameters.useSize')}
|
||||
isDisabled={!imageDTO || (tab === 'canvas' && isStaging)}
|
||||
/>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
<DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
RegionalGuidanceRefImageImage.displayName = 'RegionalGuidanceRefImageImage';
|
||||
@@ -30,6 +30,7 @@ import type {
|
||||
FluxKontextReferenceImageConfig,
|
||||
Gemini2_5ReferenceImageConfig,
|
||||
IPAdapterConfig,
|
||||
RegionalGuidanceIPAdapterConfig,
|
||||
T2IAdapterConfig,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import {
|
||||
@@ -38,6 +39,7 @@ import {
|
||||
initialFluxKontextReferenceImage,
|
||||
initialGemini2_5ReferenceImage,
|
||||
initialIPAdapter,
|
||||
initialRegionalGuidanceIPAdapter,
|
||||
initialT2IAdapter,
|
||||
} from 'features/controlLayers/store/util';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
@@ -125,7 +127,7 @@ export const getDefaultRefImageConfig = (
|
||||
return config;
|
||||
};
|
||||
|
||||
export const getDefaultRegionalGuidanceRefImageConfig = (getState: AppGetState): IPAdapterConfig => {
|
||||
export const getDefaultRegionalGuidanceRefImageConfig = (getState: AppGetState): RegionalGuidanceIPAdapterConfig => {
|
||||
// Regional guidance ref images do not support ChatGPT-4o, so we always return the IP Adapter config.
|
||||
const state = getState();
|
||||
|
||||
@@ -138,7 +140,7 @@ export const getDefaultRegionalGuidanceRefImageConfig = (getState: AppGetState):
|
||||
const modelConfig = ipAdapterModelConfigs.find((m) => m.base === base);
|
||||
|
||||
// Clone the initial IP Adapter config and set the model if available.
|
||||
const config = deepClone(initialIPAdapter);
|
||||
const config = deepClone(initialRegionalGuidanceIPAdapter);
|
||||
|
||||
if (modelConfig) {
|
||||
config.model = zModelIdentifierField.parse(modelConfig);
|
||||
|
||||
@@ -32,7 +32,12 @@ import type {
|
||||
RefImageState,
|
||||
RegionalGuidanceRefImageState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
|
||||
import {
|
||||
imageDTOToCroppableImage,
|
||||
imageDTOToImageObject,
|
||||
imageDTOToImageWithDims,
|
||||
initialControlNet,
|
||||
} from 'features/controlLayers/store/util';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
@@ -209,7 +214,7 @@ export const useNewGlobalReferenceImageFromBbox = () => {
|
||||
const overrides: Partial<RefImageState> = {
|
||||
config: {
|
||||
...getDefaultRefImageConfig(getState),
|
||||
image: imageDTOToImageWithDims(imageDTO),
|
||||
image: imageDTOToCroppableImage(imageDTO),
|
||||
},
|
||||
};
|
||||
dispatch(refImageAdded({ overrides }));
|
||||
@@ -312,7 +317,7 @@ export const usePullBboxIntoGlobalReferenceImage = (id: string) => {
|
||||
|
||||
const arg = useMemo<UseSaveCanvasArg>(() => {
|
||||
const onSave = (imageDTO: ImageDTO, _: Rect) => {
|
||||
dispatch(refImageImageChanged({ id, imageDTO }));
|
||||
dispatch(refImageImageChanged({ id, croppableImage: imageDTOToCroppableImage(imageDTO) }));
|
||||
};
|
||||
|
||||
return {
|
||||
|
||||
@@ -82,10 +82,10 @@ import {
|
||||
IMAGEN_ASPECT_RATIOS,
|
||||
isChatGPT4oAspectRatioID,
|
||||
isFluxKontextAspectRatioID,
|
||||
isFLUXReduxConfig,
|
||||
isGemini2_5AspectRatioID,
|
||||
isImagenAspectRatioID,
|
||||
isIPAdapterConfig,
|
||||
isRegionalGuidanceFLUXReduxConfig,
|
||||
isRegionalGuidanceIPAdapterConfig,
|
||||
zCanvasState,
|
||||
} from './types';
|
||||
import {
|
||||
@@ -99,6 +99,7 @@ import {
|
||||
initialControlNet,
|
||||
initialFLUXRedux,
|
||||
initialIPAdapter,
|
||||
initialRegionalGuidanceIPAdapter,
|
||||
initialT2IAdapter,
|
||||
makeDefaultRasterLayerAdjustments,
|
||||
} from './util';
|
||||
@@ -804,7 +805,7 @@ const slice = createSlice({
|
||||
if (!entity) {
|
||||
return;
|
||||
}
|
||||
const config = { id: referenceImageId, config: deepClone(initialIPAdapter) };
|
||||
const config = { id: referenceImageId, config: deepClone(initialRegionalGuidanceIPAdapter) };
|
||||
merge(config, overrides);
|
||||
entity.referenceImages.push(config);
|
||||
},
|
||||
@@ -847,7 +848,7 @@ const slice = createSlice({
|
||||
if (!referenceImage) {
|
||||
return;
|
||||
}
|
||||
if (!isIPAdapterConfig(referenceImage.config)) {
|
||||
if (!isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -864,7 +865,7 @@ const slice = createSlice({
|
||||
if (!referenceImage) {
|
||||
return;
|
||||
}
|
||||
if (!isIPAdapterConfig(referenceImage.config)) {
|
||||
if (!isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
|
||||
return;
|
||||
}
|
||||
referenceImage.config.beginEndStepPct = beginEndStepPct;
|
||||
@@ -880,7 +881,7 @@ const slice = createSlice({
|
||||
if (!referenceImage) {
|
||||
return;
|
||||
}
|
||||
if (!isIPAdapterConfig(referenceImage.config)) {
|
||||
if (!isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
|
||||
return;
|
||||
}
|
||||
referenceImage.config.method = method;
|
||||
@@ -899,7 +900,7 @@ const slice = createSlice({
|
||||
if (!referenceImage) {
|
||||
return;
|
||||
}
|
||||
if (!isFLUXReduxConfig(referenceImage.config)) {
|
||||
if (!isRegionalGuidanceFLUXReduxConfig(referenceImage.config)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -928,7 +929,7 @@ const slice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
if (isIPAdapterConfig(referenceImage.config) && isFluxReduxModelConfig(modelConfig)) {
|
||||
if (isRegionalGuidanceIPAdapterConfig(referenceImage.config) && isFluxReduxModelConfig(modelConfig)) {
|
||||
// Switching from ip_adapter to flux_redux
|
||||
referenceImage.config = {
|
||||
...initialFLUXRedux,
|
||||
@@ -938,7 +939,7 @@ const slice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
if (isFLUXReduxConfig(referenceImage.config) && isIPAdapterModelConfig(modelConfig)) {
|
||||
if (isRegionalGuidanceFLUXReduxConfig(referenceImage.config) && isIPAdapterModelConfig(modelConfig)) {
|
||||
// Switching from flux_redux to ip_adapter
|
||||
referenceImage.config = {
|
||||
...initialIPAdapter,
|
||||
@@ -948,7 +949,7 @@ const slice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
if (isIPAdapterConfig(referenceImage.config)) {
|
||||
if (isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
|
||||
referenceImage.config.model = zModelIdentifierField.parse(modelConfig);
|
||||
|
||||
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
|
||||
@@ -971,7 +972,7 @@ const slice = createSlice({
|
||||
if (!referenceImage) {
|
||||
return;
|
||||
}
|
||||
if (!isIPAdapterConfig(referenceImage.config)) {
|
||||
if (!isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
|
||||
return;
|
||||
}
|
||||
referenceImage.config.clipVisionModel = clipVisionModel;
|
||||
|
||||
@@ -199,11 +199,7 @@ const slice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
if (state.positivePromptHistory.includes(prompt)) {
|
||||
return;
|
||||
}
|
||||
|
||||
state.positivePromptHistory.unshift(prompt);
|
||||
state.positivePromptHistory = [prompt, ...state.positivePromptHistory.filter((p) => p !== prompt)];
|
||||
|
||||
if (state.positivePromptHistory.length > MAX_POSITIVE_PROMPT_HISTORY) {
|
||||
state.positivePromptHistory = state.positivePromptHistory.slice(0, MAX_POSITIVE_PROMPT_HISTORY);
|
||||
|
||||
@@ -6,13 +6,16 @@ import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { clamp } from 'es-toolkit/compat';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type { FLUXReduxImageInfluence, RefImagesState } from 'features/controlLayers/store/types';
|
||||
import type {
|
||||
CroppableImageWithDims,
|
||||
FLUXReduxImageInfluence,
|
||||
RefImagesState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type {
|
||||
ChatGPT4oModelConfig,
|
||||
FLUXKontextModelConfig,
|
||||
FLUXReduxModelConfig,
|
||||
ImageDTO,
|
||||
IPAdapterModelConfig,
|
||||
} from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -22,7 +25,6 @@ import type { CLIPVisionModelV2, IPMethodV2, RefImageState } from './types';
|
||||
import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig, zRefImagesState } from './types';
|
||||
import {
|
||||
getReferenceImageState,
|
||||
imageDTOToImageWithDims,
|
||||
initialChatGPT4oReferenceImage,
|
||||
initialFluxKontextReferenceImage,
|
||||
initialFLUXRedux,
|
||||
@@ -65,13 +67,13 @@ const slice = createSlice({
|
||||
state.entities.push(...entities);
|
||||
}
|
||||
},
|
||||
refImageImageChanged: (state, action: PayloadActionWithId<{ imageDTO: ImageDTO | null }>) => {
|
||||
const { id, imageDTO } = action.payload;
|
||||
refImageImageChanged: (state, action: PayloadActionWithId<{ croppableImage: CroppableImageWithDims | null }>) => {
|
||||
const { id, croppableImage } = action.payload;
|
||||
const entity = selectRefImageEntity(state, id);
|
||||
if (!entity) {
|
||||
return;
|
||||
}
|
||||
entity.config.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
|
||||
entity.config.image = croppableImage;
|
||||
},
|
||||
refImageIPAdapterMethodChanged: (state, action: PayloadActionWithId<{ method: IPMethodV2 }>) => {
|
||||
const { id, method } = action.payload;
|
||||
|
||||
@@ -37,6 +37,45 @@ export const zImageWithDims = z.object({
|
||||
});
|
||||
export type ImageWithDims = z.infer<typeof zImageWithDims>;
|
||||
|
||||
const zCropBox = z.object({
|
||||
x: z.number().min(0),
|
||||
y: z.number().min(0),
|
||||
width: z.number().positive(),
|
||||
height: z.number().positive(),
|
||||
});
|
||||
// This new schema is an extension of zImageWithDims, with an optional crop field.
|
||||
//
|
||||
// When we added cropping support to certain entities (e.g. Ref Images, video Starting Frame Image), we changed
|
||||
// their schemas from using zImageWithDims to this new schema. To support loading pre-existing entities that
|
||||
// were created before cropping was supported, we can use zod's preprocess to transform old data into the new format.
|
||||
// Its essentially a data migration step.
|
||||
//
|
||||
// This parsing happens currently in two places:
|
||||
// - Recalling metadata.
|
||||
// - Loading/rehydrating persisted client state from storage.
|
||||
export const zCroppableImageWithDims = z.preprocess(
|
||||
(val) => {
|
||||
try {
|
||||
const imageWithDims = zImageWithDims.parse(val);
|
||||
const migrated = { original: { image: deepClone(imageWithDims) } };
|
||||
return migrated;
|
||||
} catch {
|
||||
return val;
|
||||
}
|
||||
},
|
||||
z.object({
|
||||
original: z.object({ image: zImageWithDims }),
|
||||
crop: z
|
||||
.object({
|
||||
box: zCropBox,
|
||||
ratio: z.number().gt(0).nullable(),
|
||||
image: zImageWithDims,
|
||||
})
|
||||
.optional(),
|
||||
})
|
||||
);
|
||||
export type CroppableImageWithDims = z.infer<typeof zCroppableImageWithDims>;
|
||||
|
||||
const zImageWithDimsDataURL = z.object({
|
||||
dataURL: z.string(),
|
||||
width: z.number().int().positive(),
|
||||
@@ -235,7 +274,7 @@ export type CanvasObjectState = z.infer<typeof zCanvasObjectState>;
|
||||
|
||||
const zIPAdapterConfig = z.object({
|
||||
type: z.literal('ip_adapter'),
|
||||
image: zImageWithDims.nullable(),
|
||||
image: zCroppableImageWithDims.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
beginEndStepPct: zBeginEndStepPct,
|
||||
@@ -244,21 +283,39 @@ const zIPAdapterConfig = z.object({
|
||||
});
|
||||
export type IPAdapterConfig = z.infer<typeof zIPAdapterConfig>;
|
||||
|
||||
const zRegionalGuidanceIPAdapterConfig = z.object({
|
||||
type: z.literal('ip_adapter'),
|
||||
image: zImageWithDims.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
beginEndStepPct: zBeginEndStepPct,
|
||||
method: zIPMethodV2,
|
||||
clipVisionModel: zCLIPVisionModelV2,
|
||||
});
|
||||
export type RegionalGuidanceIPAdapterConfig = z.infer<typeof zRegionalGuidanceIPAdapterConfig>;
|
||||
|
||||
const zFLUXReduxImageInfluence = z.enum(['lowest', 'low', 'medium', 'high', 'highest']);
|
||||
export const isFLUXReduxImageInfluence = (v: unknown): v is FLUXReduxImageInfluence =>
|
||||
zFLUXReduxImageInfluence.safeParse(v).success;
|
||||
export type FLUXReduxImageInfluence = z.infer<typeof zFLUXReduxImageInfluence>;
|
||||
const zFLUXReduxConfig = z.object({
|
||||
type: z.literal('flux_redux'),
|
||||
image: zImageWithDims.nullable(),
|
||||
image: zCroppableImageWithDims.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
imageInfluence: zFLUXReduxImageInfluence.default('highest'),
|
||||
});
|
||||
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
|
||||
const zRegionalGuidanceFLUXReduxConfig = z.object({
|
||||
type: z.literal('flux_redux'),
|
||||
image: zImageWithDims.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
imageInfluence: zFLUXReduxImageInfluence.default('highest'),
|
||||
});
|
||||
type RegionalGuidanceFLUXReduxConfig = z.infer<typeof zRegionalGuidanceFLUXReduxConfig>;
|
||||
|
||||
const zChatGPT4oReferenceImageConfig = z.object({
|
||||
type: z.literal('chatgpt_4o_reference_image'),
|
||||
image: zImageWithDims.nullable(),
|
||||
image: zCroppableImageWithDims.nullable(),
|
||||
/**
|
||||
* TODO(psyche): Technically there is no model for ChatGPT 4o reference images - it's just a field in the API call.
|
||||
* But we use a model drop down to switch between different ref image types, so there needs to be a model here else
|
||||
@@ -270,14 +327,14 @@ export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceIm
|
||||
|
||||
const zGemini2_5ReferenceImageConfig = z.object({
|
||||
type: z.literal('gemini_2_5_reference_image'),
|
||||
image: zImageWithDims.nullable(),
|
||||
image: zCroppableImageWithDims.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
});
|
||||
export type Gemini2_5ReferenceImageConfig = z.infer<typeof zGemini2_5ReferenceImageConfig>;
|
||||
|
||||
const zFluxKontextReferenceImageConfig = z.object({
|
||||
type: z.literal('flux_kontext_reference_image'),
|
||||
image: zImageWithDims.nullable(),
|
||||
image: zCroppableImageWithDims.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
});
|
||||
export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>;
|
||||
@@ -307,6 +364,7 @@ export const isIPAdapterConfig = (config: RefImageState['config']): config is IP
|
||||
|
||||
export const isFLUXReduxConfig = (config: RefImageState['config']): config is FLUXReduxConfig =>
|
||||
config.type === 'flux_redux';
|
||||
|
||||
export const isChatGPT4oReferenceImageConfig = (
|
||||
config: RefImageState['config']
|
||||
): config is ChatGPT4oReferenceImageConfig => config.type === 'chatgpt_4o_reference_image';
|
||||
@@ -326,10 +384,18 @@ const zFill = z.object({ style: zFillStyle, color: zRgbColor });
|
||||
|
||||
const zRegionalGuidanceRefImageState = z.object({
|
||||
id: zId,
|
||||
config: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
|
||||
config: z.discriminatedUnion('type', [zRegionalGuidanceIPAdapterConfig, zRegionalGuidanceFLUXReduxConfig]),
|
||||
});
|
||||
export type RegionalGuidanceRefImageState = z.infer<typeof zRegionalGuidanceRefImageState>;
|
||||
|
||||
export const isRegionalGuidanceIPAdapterConfig = (
|
||||
config: RegionalGuidanceRefImageState['config']
|
||||
): config is RegionalGuidanceIPAdapterConfig => config.type === 'ip_adapter';
|
||||
|
||||
export const isRegionalGuidanceFLUXReduxConfig = (
|
||||
config: RegionalGuidanceRefImageState['config']
|
||||
): config is RegionalGuidanceFLUXReduxConfig => config.type === 'flux_redux';
|
||||
|
||||
const zCanvasRegionalGuidanceState = zCanvasEntityBase.extend({
|
||||
type: z.literal('regional_guidance'),
|
||||
position: zCoordinate,
|
||||
|
||||
@@ -10,6 +10,7 @@ import type {
|
||||
ChatGPT4oReferenceImageConfig,
|
||||
ControlLoRAConfig,
|
||||
ControlNetConfig,
|
||||
CroppableImageWithDims,
|
||||
FluxKontextReferenceImageConfig,
|
||||
FLUXReduxConfig,
|
||||
Gemini2_5ReferenceImageConfig,
|
||||
@@ -17,6 +18,7 @@ import type {
|
||||
IPAdapterConfig,
|
||||
RasterLayerAdjustments,
|
||||
RefImageState,
|
||||
RegionalGuidanceIPAdapterConfig,
|
||||
RgbColor,
|
||||
T2IAdapterConfig,
|
||||
} from 'features/controlLayers/store/types';
|
||||
@@ -45,6 +47,21 @@ export const imageDTOToImageWithDims = ({ image_name, width, height }: ImageDTO)
|
||||
height,
|
||||
});
|
||||
|
||||
export const imageDTOToCroppableImage = (
|
||||
originalImageDTO: ImageDTO,
|
||||
crop?: CroppableImageWithDims['crop']
|
||||
): CroppableImageWithDims => {
|
||||
const { image_name, width, height } = originalImageDTO;
|
||||
const val: CroppableImageWithDims = {
|
||||
original: { image: { image_name, width, height } },
|
||||
};
|
||||
if (crop) {
|
||||
val.crop = deepClone(crop);
|
||||
}
|
||||
|
||||
return val;
|
||||
};
|
||||
|
||||
export const imageDTOToImageField = ({ image_name }: ImageDTO): ImageField => ({ image_name });
|
||||
|
||||
const DEFAULT_RG_MASK_FILL_COLORS: RgbColor[] = [
|
||||
@@ -79,6 +96,15 @@ export const initialIPAdapter: IPAdapterConfig = {
|
||||
clipVisionModel: 'ViT-H',
|
||||
weight: 1,
|
||||
};
|
||||
export const initialRegionalGuidanceIPAdapter: RegionalGuidanceIPAdapterConfig = {
|
||||
type: 'ip_adapter',
|
||||
image: null,
|
||||
model: null,
|
||||
beginEndStepPct: [0, 1],
|
||||
method: 'full',
|
||||
clipVisionModel: 'ViT-H',
|
||||
weight: 1,
|
||||
};
|
||||
export const initialFLUXRedux: FLUXReduxConfig = {
|
||||
type: 'flux_redux',
|
||||
image: null,
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
import {
|
||||
Button,
|
||||
ButtonGroup,
|
||||
Divider,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Select,
|
||||
Spacer,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { AspectRatioID } from 'features/controlLayers/store/types';
|
||||
import { ASPECT_RATIO_MAP, isAspectRatioID } from 'features/controlLayers/store/types';
|
||||
import type { CropBox } from 'features/cropper/lib/editor';
|
||||
import { cropImageModalApi, type CropImageModalState } from 'features/cropper/store';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import React, { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useUploadImageMutation } from 'services/api/endpoints/images';
|
||||
import { objectEntries } from 'tsafe';
|
||||
|
||||
type Props = {
|
||||
editor: CropImageModalState['editor'];
|
||||
onApplyCrop: CropImageModalState['onApplyCrop'];
|
||||
onReady: CropImageModalState['onReady'];
|
||||
};
|
||||
|
||||
const getAspectRatioString = (ratio: number | null): AspectRatioID => {
|
||||
if (!ratio) {
|
||||
return 'Free';
|
||||
}
|
||||
const entries = objectEntries(ASPECT_RATIO_MAP);
|
||||
for (const [key, value] of entries) {
|
||||
if (value.ratio === ratio) {
|
||||
return key;
|
||||
}
|
||||
}
|
||||
return 'Free';
|
||||
};
|
||||
|
||||
export const CropImageEditor = memo(({ editor, onApplyCrop, onReady }: Props) => {
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const [zoom, setZoom] = useState(100);
|
||||
const [cropBox, setCropBox] = useState<CropBox | null>(null);
|
||||
const [aspectRatio, setAspectRatio] = useState<string>('free');
|
||||
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
|
||||
|
||||
const [uploadImage] = useUploadImageMutation({ fixedCacheKey: 'editorContainer' });
|
||||
|
||||
const setup = useCallback(
|
||||
async (container: HTMLDivElement) => {
|
||||
editor.init(container);
|
||||
editor.onZoomChange((zoom) => {
|
||||
setZoom(zoom);
|
||||
});
|
||||
editor.onCropBoxChange((crop) => {
|
||||
setCropBox(crop);
|
||||
});
|
||||
editor.onAspectRatioChange((ratio) => {
|
||||
setAspectRatio(getAspectRatioString(ratio));
|
||||
});
|
||||
await onReady();
|
||||
editor.fitToContainer();
|
||||
},
|
||||
[editor, onReady]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
setup(container);
|
||||
const handleResize = () => {
|
||||
editor.resize(container.clientWidth, container.clientHeight);
|
||||
};
|
||||
|
||||
const resizeObserver = new ResizeObserver(handleResize);
|
||||
resizeObserver.observe(container);
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
}, [editor, setup]);
|
||||
|
||||
const handleAspectRatioChange = useCallback(
|
||||
(e: React.ChangeEvent<HTMLSelectElement>) => {
|
||||
const newRatio = e.target.value;
|
||||
if (!isAspectRatioID(newRatio)) {
|
||||
return;
|
||||
}
|
||||
setAspectRatio(newRatio);
|
||||
|
||||
if (newRatio === 'Free') {
|
||||
editor.setCropAspectRatio(null);
|
||||
} else {
|
||||
editor.setCropAspectRatio(ASPECT_RATIO_MAP[newRatio]?.ratio ?? null);
|
||||
}
|
||||
},
|
||||
[editor]
|
||||
);
|
||||
|
||||
const handleResetCrop = useCallback(() => {
|
||||
editor.resetCrop();
|
||||
}, [editor]);
|
||||
|
||||
const handleApplyCrop = useCallback(async () => {
|
||||
await onApplyCrop();
|
||||
cropImageModalApi.close();
|
||||
}, [onApplyCrop]);
|
||||
|
||||
const handleCancelCrop = useCallback(() => {
|
||||
cropImageModalApi.close();
|
||||
}, []);
|
||||
|
||||
const handleExport = useCallback(async () => {
|
||||
try {
|
||||
const blob = await editor.exportImage('blob');
|
||||
const file = new File([blob], 'image.png', { type: 'image/png' });
|
||||
|
||||
await uploadImage({
|
||||
file,
|
||||
is_intermediate: false,
|
||||
image_category: 'user',
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
}).unwrap();
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.message.includes('tainted')) {
|
||||
alert(
|
||||
'Cannot export image: The image is from a different domain (CORS issue). To fix this:\n\n1. Load images from the same domain\n2. Use images from CORS-enabled sources\n3. Upload a local image file instead'
|
||||
);
|
||||
} else {
|
||||
alert(`Export failed: ${err instanceof Error ? err.message : String(err)}`);
|
||||
}
|
||||
}
|
||||
}, [autoAddBoardId, editor, uploadImage]);
|
||||
|
||||
const zoomIn = useCallback(() => {
|
||||
editor.zoomIn();
|
||||
}, [editor]);
|
||||
|
||||
const zoomOut = useCallback(() => {
|
||||
editor.zoomOut();
|
||||
}, [editor]);
|
||||
|
||||
const fitToContainer = useCallback(() => {
|
||||
editor.fitToContainer();
|
||||
}, [editor]);
|
||||
|
||||
const resetView = useCallback(() => {
|
||||
editor.resetView();
|
||||
}, [editor]);
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" flexDir="column" gap={4}>
|
||||
<Flex gap={2} alignItems="center">
|
||||
<FormControl flex={1}>
|
||||
<FormLabel>Aspect Ratio:</FormLabel>
|
||||
<Select size="sm" value={aspectRatio} onChange={handleAspectRatioChange} w={32}>
|
||||
<option value="Free">Free</option>
|
||||
<option value="16:9">16:9</option>
|
||||
<option value="3:2">3:2</option>
|
||||
<option value="4:3">4:3</option>
|
||||
<option value="1:1">1:1</option>
|
||||
<option value="3:4">3:4</option>
|
||||
<option value="2:3">2:3</option>
|
||||
<option value="9:16">9:16</option>
|
||||
</Select>
|
||||
</FormControl>
|
||||
|
||||
<Spacer />
|
||||
|
||||
<ButtonGroup size="sm" isAttached={false}>
|
||||
<Button onClick={fitToContainer}>Fit View</Button>
|
||||
<Button onClick={resetView}>Reset View</Button>
|
||||
<Button onClick={zoomIn}>Zoom In</Button>
|
||||
<Button onClick={zoomOut}>Zoom Out</Button>
|
||||
</ButtonGroup>
|
||||
|
||||
<Spacer />
|
||||
|
||||
<ButtonGroup size="sm" isAttached={false}>
|
||||
<Button onClick={handleApplyCrop}>Apply</Button>
|
||||
<Button onClick={handleResetCrop}>Reset</Button>
|
||||
<Button onClick={handleCancelCrop}>Cancel</Button>
|
||||
<Button onClick={handleExport}>Save to Assets</Button>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
|
||||
<Flex position="relative" w="full" h="full" bg="base.900">
|
||||
<Flex position="absolute" inset={0} ref={containerRef} />
|
||||
</Flex>
|
||||
|
||||
<Flex gap={2} color="base.300">
|
||||
<Text>Mouse wheel: Zoom</Text>
|
||||
<Divider orientation="vertical" />
|
||||
<Text>Space + Drag: Pan</Text>
|
||||
<Divider orientation="vertical" />
|
||||
<Text>Drag crop box or handles to adjust</Text>
|
||||
{cropBox && (
|
||||
<>
|
||||
<Divider orientation="vertical" />
|
||||
<Text>
|
||||
X: {Math.round(cropBox.x)}, Y: {Math.round(cropBox.y)}, Width: {Math.round(cropBox.width)}, Height:{' '}
|
||||
{Math.round(cropBox.height)}
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
<Spacer key="help-spacer" />
|
||||
<Text key="help-zoom">Zoom: {Math.round(zoom * 100)}%</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
CropImageEditor.displayName = 'CropImageEditor';
|
||||
@@ -0,0 +1,29 @@
|
||||
import { Modal, ModalBody, ModalContent, ModalHeader, ModalOverlay } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { cropImageModalApi } from 'features/cropper/store';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { CropImageEditor } from './CropImageEditor';
|
||||
|
||||
export const CropImageModal = memo(() => {
|
||||
const state = useStore(cropImageModalApi.$state);
|
||||
|
||||
if (!state) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
// This modal is always open when this component is rendered
|
||||
<Modal isOpen={true} onClose={cropImageModalApi.close} isCentered useInert={false} size="full">
|
||||
<ModalOverlay />
|
||||
<ModalContent minH="unset" minW="unset" maxH="90vh" maxW="90vw" w="full" h="full" borderRadius="base">
|
||||
<ModalHeader>Crop Image</ModalHeader>
|
||||
<ModalBody px={4} pb={4} pt={0}>
|
||||
<CropImageEditor editor={state.editor} onApplyCrop={state.onApplyCrop} onReady={state.onReady} />
|
||||
</ModalBody>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
});
|
||||
|
||||
CropImageModal.displayName = 'CropImageModal';
|
||||
1557
invokeai/frontend/web/src/features/cropper/lib/editor.ts
Normal file
1557
invokeai/frontend/web/src/features/cropper/lib/editor.ts
Normal file
File diff suppressed because it is too large
Load Diff
26
invokeai/frontend/web/src/features/cropper/store/index.ts
Normal file
26
invokeai/frontend/web/src/features/cropper/store/index.ts
Normal file
@@ -0,0 +1,26 @@
|
||||
import type { Editor } from 'features/cropper/lib/editor';
|
||||
import { atom } from 'nanostores';
|
||||
|
||||
export type CropImageModalState = {
|
||||
editor: Editor;
|
||||
onApplyCrop: () => Promise<void> | void;
|
||||
onReady: () => Promise<void> | void;
|
||||
};
|
||||
|
||||
const $state = atom<CropImageModalState | null>(null);
|
||||
|
||||
const open = (state: CropImageModalState) => {
|
||||
$state.set(state);
|
||||
};
|
||||
|
||||
const close = () => {
|
||||
const state = $state.get();
|
||||
state?.editor.destroy();
|
||||
$state.set(null);
|
||||
};
|
||||
|
||||
export const cropImageModalApi = {
|
||||
$state,
|
||||
open,
|
||||
close,
|
||||
};
|
||||
@@ -236,8 +236,11 @@ const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, image
|
||||
|
||||
const deleteReferenceImages = (state: RootState, dispatch: AppDispatch, image_name: string) => {
|
||||
selectReferenceImageEntities(state).forEach((entity) => {
|
||||
if (entity.config.image?.image_name === image_name) {
|
||||
dispatch(refImageImageChanged({ id: entity.id, imageDTO: null }));
|
||||
if (
|
||||
entity.config.image?.original.image.image_name === image_name ||
|
||||
entity.config.image?.crop?.image.image_name === image_name
|
||||
) {
|
||||
dispatch(refImageImageChanged({ id: entity.id, croppableImage: null }));
|
||||
}
|
||||
});
|
||||
};
|
||||
@@ -284,7 +287,10 @@ export const getImageUsage = (
|
||||
|
||||
const isUpscaleImage = upscale.upscaleInitialImage?.image_name === image_name;
|
||||
|
||||
const isReferenceImage = refImages.entities.some(({ config }) => config.image?.image_name === image_name);
|
||||
const isReferenceImage = refImages.entities.some(
|
||||
({ config }) =>
|
||||
config.image?.original.image.image_name === image_name || config.image?.crop?.image.image_name === image_name
|
||||
);
|
||||
|
||||
const isRasterLayerImage = canvas.rasterLayers.entities.some(({ objects }) =>
|
||||
objects.some((obj) => obj.type === 'image' && 'image_name' in obj.image && obj.image.image_name === image_name)
|
||||
|
||||
@@ -3,7 +3,7 @@ import { IconButton } from '@invoke-ai/ui-library';
|
||||
import type { MouseEvent } from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
const sx: SystemStyleObject = {
|
||||
export const imageButtonSx: SystemStyleObject = {
|
||||
minW: 0,
|
||||
svg: {
|
||||
transitionProperty: 'common',
|
||||
@@ -31,7 +31,7 @@ export const DndImageIcon = memo((props: Props) => {
|
||||
aria-label={tooltip}
|
||||
icon={icon}
|
||||
variant="link"
|
||||
sx={sx}
|
||||
sx={imageButtonSx}
|
||||
data-testid={tooltip}
|
||||
{...rest}
|
||||
/>
|
||||
|
||||
@@ -4,7 +4,7 @@ import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerH
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { refImageAdded } from 'features/controlLayers/store/refImagesSlice';
|
||||
import type { CanvasEntityIdentifier, CanvasEntityType } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
|
||||
import { selectComparisonImages } from 'features/gallery/components/ImageViewer/common';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import {
|
||||
@@ -211,7 +211,7 @@ export const addGlobalReferenceImageDndTarget: DndTarget<
|
||||
handler: ({ sourceData, dispatch, getState }) => {
|
||||
const { imageDTO } = sourceData.payload;
|
||||
const config = getDefaultRefImageConfig(getState);
|
||||
config.image = imageDTOToImageWithDims(imageDTO);
|
||||
config.image = imageDTOToCroppableImage(imageDTO);
|
||||
dispatch(refImageAdded({ overrides: { config } }));
|
||||
},
|
||||
};
|
||||
@@ -641,7 +641,7 @@ export const videoFrameFromImageDndTarget: DndTarget<VideoFrameFromImageDndTarge
|
||||
},
|
||||
handler: ({ sourceData, dispatch }) => {
|
||||
const { imageDTO } = sourceData.payload;
|
||||
dispatch(startingFrameImageChanged(imageDTOToImageWithDims(imageDTO)));
|
||||
dispatch(startingFrameImageChanged(imageDTOToCroppableImage(imageDTO)));
|
||||
},
|
||||
};
|
||||
//#endregion
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
|
||||
import { useItemDTOContextImageOnly } from 'features/gallery/contexts/ItemDTOContext';
|
||||
import { startingFrameImageChanged } from 'features/parameters/store/videoSlice';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
@@ -13,7 +14,7 @@ export const ContextMenuItemSendToVideo = memo(() => {
|
||||
const dispatch = useDispatch();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(startingFrameImageChanged(imageDTO));
|
||||
dispatch(startingFrameImageChanged(imageDTOToCroppableImage(imageDTO)));
|
||||
navigationApi.switchToTab('video');
|
||||
}, [imageDTO, dispatch]);
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { refImageAdded } from 'features/controlLayers/store/refImagesSlice';
|
||||
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
|
||||
import { useItemDTOContextImageOnly } from 'features/gallery/contexts/ItemDTOContext';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
@@ -17,7 +17,7 @@ export const ContextMenuItemUseAsRefImage = memo(() => {
|
||||
const onClickNewGlobalReferenceImageFromImage = useCallback(() => {
|
||||
const { dispatch, getState } = store;
|
||||
const config = getDefaultRefImageConfig(getState);
|
||||
config.image = imageDTOToImageWithDims(imageDTO);
|
||||
config.image = imageDTOToCroppableImage(imageDTO);
|
||||
dispatch(refImageAdded({ overrides: { config } }));
|
||||
toast({
|
||||
id: 'SENT_TO_CANVAS',
|
||||
|
||||
@@ -26,7 +26,12 @@ import type {
|
||||
CanvasRasterLayerState,
|
||||
CanvasRegionalGuidanceState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
|
||||
import {
|
||||
imageDTOToCroppableImage,
|
||||
imageDTOToImageObject,
|
||||
imageDTOToImageWithDims,
|
||||
initialControlNet,
|
||||
} from 'features/controlLayers/store/util';
|
||||
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
@@ -44,7 +49,7 @@ import { assert } from 'tsafe';
|
||||
|
||||
export const setGlobalReferenceImage = (arg: { imageDTO: ImageDTO; id: string; dispatch: AppDispatch }) => {
|
||||
const { imageDTO, id, dispatch } = arg;
|
||||
dispatch(refImageImageChanged({ id, imageDTO }));
|
||||
dispatch(refImageImageChanged({ id, croppableImage: imageDTOToCroppableImage(imageDTO) }));
|
||||
};
|
||||
|
||||
export const setRegionalGuidanceReferenceImage = (arg: {
|
||||
|
||||
@@ -975,7 +975,7 @@ const RefImages: CollectionMetadataHandler<RefImageState[]> = {
|
||||
|
||||
for (const refImage of parsed) {
|
||||
if (refImage.config.image) {
|
||||
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
|
||||
await throwIfImageDoesNotExist(refImage.config.image.original.image.image_name, store);
|
||||
}
|
||||
if (refImage.config.model) {
|
||||
await throwIfModelDoesNotExist(refImage.config.model.key, store);
|
||||
|
||||
@@ -35,7 +35,7 @@ export const LaunchpadForm = memo(() => {
|
||||
return (
|
||||
<Flex flexDir="column" height="100%" gap={3}>
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column" gap={6} p={3}>
|
||||
<Flex flexDir="column" gap={6} py={2}>
|
||||
{/* Welcome Section */}
|
||||
<Flex flexDir="column" gap={2} alignItems="flex-start">
|
||||
<Heading size="md">{t('modelManager.launchpad.welcome')}</Heading>
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import { Badge, Button, Flex } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold, PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
type Props = {
|
||||
handleInstall: () => void;
|
||||
isInstalled: boolean;
|
||||
};
|
||||
|
||||
export const ModelResultItemActions = memo(({ handleInstall, isInstalled }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex gap={2} shrink={0} pt={1}>
|
||||
{isInstalled ? (
|
||||
// TODO: Add a link button to navigate to model
|
||||
<Badge
|
||||
variant="subtle"
|
||||
colorScheme="green"
|
||||
display="flex"
|
||||
gap={1}
|
||||
alignItems="center"
|
||||
borderRadius="base"
|
||||
h="24px"
|
||||
>
|
||||
<PiCheckBold size="14px" />
|
||||
</Badge>
|
||||
) : (
|
||||
<Button
|
||||
onClick={handleInstall}
|
||||
rightIcon={<PiPlusBold size="14px" />}
|
||||
textTransform="uppercase"
|
||||
letterSpacing="wider"
|
||||
fontSize="9px"
|
||||
size="sm"
|
||||
>
|
||||
{t('modelManager.install')}
|
||||
</Button>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelResultItemActions.displayName = 'ModelResultItemActions';
|
||||
@@ -1,33 +1,56 @@
|
||||
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { ModelResultItemActions } from 'features/modelManagerV2/subpanels/AddModelPanel/ModelResultItemActions';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||
|
||||
type Props = {
|
||||
result: ScanFolderResponse[number];
|
||||
installModel: (source: string) => void;
|
||||
};
|
||||
export const ScanModelResultItem = memo(({ result, installModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const scanFolderResultItemSx: SystemStyleObject = {
|
||||
alignItems: 'center',
|
||||
justifyContent: 'space-between',
|
||||
w: '100%',
|
||||
py: 2,
|
||||
px: 1,
|
||||
gap: 3,
|
||||
borderBottomWidth: '1px',
|
||||
borderColor: 'base.700',
|
||||
};
|
||||
|
||||
export const ScanModelResultItem = memo(({ result, installModel }: Props) => {
|
||||
const handleInstall = useCallback(() => {
|
||||
installModel(result.path);
|
||||
}, [installModel, result]);
|
||||
|
||||
const modelDisplayName = useMemo(() => {
|
||||
const normalizedPath = result.path.replace(/\\/g, '/').replace(/\/+$/, '');
|
||||
|
||||
// Extract filename/folder name from path
|
||||
const lastSlashIndex = normalizedPath.lastIndexOf('/');
|
||||
return lastSlashIndex === -1 ? normalizedPath : normalizedPath.slice(lastSlashIndex + 1);
|
||||
}, [result.path]);
|
||||
|
||||
const modelPathParts = result.path.split(/[/\\]/);
|
||||
|
||||
return (
|
||||
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
|
||||
<Flex sx={scanFolderResultItemSx}>
|
||||
<Flex fontSize="sm" flexDir="column">
|
||||
<Text fontWeight="semibold">{result.path.split('\\').slice(-1)[0]}</Text>
|
||||
<Text variant="subtext">{result.path}</Text>
|
||||
{/* Model Title */}
|
||||
<Text fontWeight="semibold">{modelDisplayName}</Text>
|
||||
{/* Model Path */}
|
||||
<Flex flexWrap="wrap" color="base.200">
|
||||
{modelPathParts.map((part, index) => (
|
||||
<Text key={index} variant="subtext">
|
||||
{part}
|
||||
{index < modelPathParts.length - 1 && '/'}
|
||||
</Text>
|
||||
))}
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Box>
|
||||
{result.is_installed ? (
|
||||
<Badge>{t('common.installed')}</Badge>
|
||||
) : (
|
||||
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleInstall} size="sm" />
|
||||
)}
|
||||
</Box>
|
||||
<ModelResultItemActions handleInstall={handleInstall} isInstalled={result.is_installed} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -113,9 +113,9 @@ export const ScanModelsResults = memo(({ results }: ScanModelResultsProps) => {
|
||||
</InputGroup>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
|
||||
<Flex height="100%" layerStyle="second" borderRadius="base" px={2}>
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column" gap={3}>
|
||||
<Flex flexDir="column">
|
||||
{filteredResults.map((result) => (
|
||||
<ScanModelResultItem key={result.path} result={result} installModel={handleInstallOne} />
|
||||
))}
|
||||
|
||||
@@ -13,6 +13,7 @@ import { useStarterBundleInstallStatus } from 'features/modelManagerV2/hooks/use
|
||||
import { t } from 'i18next';
|
||||
import type { MouseEvent } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { PiDownloadSimpleBold } from 'react-icons/pi';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
export const StarterBundleButton = ({ bundle, ...rest }: { bundle: S['StarterModelBundle'] } & ButtonProps) => {
|
||||
@@ -33,8 +34,16 @@ export const StarterBundleButton = ({ bundle, ...rest }: { bundle: S['StarterMod
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button onClick={onClickBundle} isDisabled={install.length === 0} {...rest}>
|
||||
{bundle.name}
|
||||
<Button
|
||||
display="flex"
|
||||
justifyContent="space-between"
|
||||
gap={2}
|
||||
onClick={onClickBundle}
|
||||
isDisabled={install.length === 0}
|
||||
{...rest}
|
||||
>
|
||||
<span>{bundle.name}</span>
|
||||
<PiDownloadSimpleBold size="16px" />
|
||||
</Button>
|
||||
<ConfirmationAlertDialog
|
||||
isOpen={isOpen}
|
||||
|
||||
@@ -1,17 +1,30 @@
|
||||
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Badge, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { negate } from 'es-toolkit/compat';
|
||||
import { flattenStarterModel, useBuildModelInstallArg } from 'features/modelManagerV2/hooks/useBuildModelsToInstall';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import { ModelResultItemActions } from 'features/modelManagerV2/subpanels/AddModelPanel/ModelResultItemActions';
|
||||
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import type { StarterModel } from 'services/api/types';
|
||||
|
||||
const starterModelResultItemSx: SystemStyleObject = {
|
||||
alignItems: 'start',
|
||||
justifyContent: 'space-between',
|
||||
w: '100%',
|
||||
py: 2,
|
||||
px: 1,
|
||||
gap: 2,
|
||||
borderBottomWidth: '1px',
|
||||
borderColor: 'base.700',
|
||||
};
|
||||
|
||||
type Props = {
|
||||
starterModel: StarterModel;
|
||||
};
|
||||
|
||||
export const StarterModelsResultItem = memo(({ starterModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { getIsInstalled, buildModelInstallArg } = useBuildModelInstallArg();
|
||||
@@ -40,22 +53,16 @@ export const StarterModelsResultItem = memo(({ starterModel }: Props) => {
|
||||
}, [modelsToInstall, installModel, t]);
|
||||
|
||||
return (
|
||||
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
|
||||
<Flex sx={starterModelResultItemSx}>
|
||||
<Flex fontSize="sm" flexDir="column">
|
||||
<Flex gap={3}>
|
||||
<Text fontWeight="semibold">{starterModel.name}</Text>
|
||||
<Text variant="subtext">{starterModel.description}</Text>
|
||||
<Flex gap={1} py={1} alignItems="center">
|
||||
<Badge h="min-content">{starterModel.type.replaceAll('_', ' ')}</Badge>
|
||||
<ModelBaseBadge base={starterModel.base} />
|
||||
<Text fontWeight="semibold">{starterModel.name}</Text>
|
||||
</Flex>
|
||||
<Text variant="subtext">{starterModel.description}</Text>
|
||||
</Flex>
|
||||
<Box>
|
||||
{isInstalled ? (
|
||||
<Badge>{t('common.installed')}</Badge>
|
||||
) : (
|
||||
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
|
||||
)}
|
||||
</Box>
|
||||
<ModelResultItemActions handleInstall={onClick} isInstalled={isInstalled} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -48,9 +48,9 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={3} height="100%">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<Flex gap={3} direction="column">
|
||||
{size(results.starter_bundles) > 0 && (
|
||||
<Flex gap={4} alignItems="center">
|
||||
<Flex gap={4} alignItems="center" justifyContent="space-between" p={4} borderWidth="1px" rounded="base">
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Text color="base.200" fontWeight="semibold">
|
||||
{t('modelManager.starterBundles')}
|
||||
@@ -73,7 +73,8 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
|
||||
</Flex>
|
||||
</Flex>
|
||||
)}
|
||||
<InputGroup w={64} size="xs">
|
||||
|
||||
<InputGroup w="100%" size="xs">
|
||||
<Input
|
||||
placeholder={t('modelManager.search')}
|
||||
value={searchTerm}
|
||||
@@ -96,9 +97,10 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
|
||||
)}
|
||||
</InputGroup>
|
||||
</Flex>
|
||||
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
|
||||
|
||||
<Flex height="100%" layerStyle="second" borderRadius="base" px={2}>
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column" gap={3}>
|
||||
<Flex flexDir="column">
|
||||
{filteredResults.map((result) => (
|
||||
<StarterModelsResultItem key={result.source} starterModel={result} />
|
||||
))}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Button, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $installModelsTabIndex } from 'features/modelManagerV2/store/installModelsStore';
|
||||
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiInfoBold } from 'react-icons/pi';
|
||||
import { PiCubeBold, PiFolderOpenBold, PiInfoBold, PiLinkSimpleBold, PiShootingStarBold } from 'react-icons/pi';
|
||||
import { SiHuggingface } from 'react-icons/si';
|
||||
|
||||
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
|
||||
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
|
||||
@@ -12,6 +14,12 @@ import { LaunchpadForm } from './AddModelPanel/LaunchpadForm/LaunchpadForm';
|
||||
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
|
||||
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
||||
|
||||
const installModelsTabSx: SystemStyleObject = {
|
||||
display: 'flex',
|
||||
gap: 2,
|
||||
px: 2,
|
||||
};
|
||||
|
||||
export const InstallModels = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const tabIndex = useStore($installModelsTabIndex);
|
||||
@@ -29,21 +37,36 @@ export const InstallModels = memo(() => {
|
||||
</Button>
|
||||
</Flex>
|
||||
<Tabs
|
||||
variant="collapse"
|
||||
height="50%"
|
||||
variant="line"
|
||||
height="100%"
|
||||
display="flex"
|
||||
flexDir="column"
|
||||
index={tabIndex}
|
||||
onChange={$installModelsTabIndex.set}
|
||||
>
|
||||
<TabList>
|
||||
<Tab>{t('modelManager.launchpadTab')}</Tab>
|
||||
<Tab>{t('modelManager.urlOrLocalPath')}</Tab>
|
||||
<Tab>{t('modelManager.huggingFace')}</Tab>
|
||||
<Tab>{t('modelManager.scanFolder')}</Tab>
|
||||
<Tab>{t('modelManager.starterModels')}</Tab>
|
||||
<Tab sx={installModelsTabSx}>
|
||||
<PiCubeBold />
|
||||
{t('modelManager.launchpadTab')}
|
||||
</Tab>
|
||||
<Tab sx={installModelsTabSx}>
|
||||
<PiLinkSimpleBold />
|
||||
{t('modelManager.urlOrLocalPath')}
|
||||
</Tab>
|
||||
<Tab sx={installModelsTabSx}>
|
||||
<SiHuggingface />
|
||||
{t('modelManager.huggingFace')}
|
||||
</Tab>
|
||||
<Tab sx={installModelsTabSx}>
|
||||
<PiFolderOpenBold />
|
||||
{t('modelManager.scanFolder')}
|
||||
</Tab>
|
||||
<Tab sx={installModelsTabSx}>
|
||||
<PiShootingStarBold />
|
||||
{t('modelManager.starterModels')}
|
||||
</Tab>
|
||||
</TabList>
|
||||
<TabPanels p={3} height="100%">
|
||||
<TabPanels height="100%">
|
||||
<TabPanel height="100%">
|
||||
<LaunchpadForm />
|
||||
</TabPanel>
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSelectedModelKey, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
@@ -8,6 +9,16 @@ import { PiPlusBold } from 'react-icons/pi';
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
|
||||
|
||||
const modelManagerSx: SystemStyleObject = {
|
||||
flexDir: 'column',
|
||||
p: 4,
|
||||
gap: 4,
|
||||
borderRadius: 'base',
|
||||
w: '50%',
|
||||
minWidth: '360px',
|
||||
h: 'full',
|
||||
};
|
||||
|
||||
export const ModelManager = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -17,7 +28,7 @@ export const ModelManager = memo(() => {
|
||||
const selectedModelKey = useAppSelector(selectSelectedModelKey);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" layerStyle="first" p={4} gap={4} borderRadius="base" w="50%" h="full">
|
||||
<Flex sx={modelManagerSx}>
|
||||
<Flex w="full" gap={4} justifyContent="space-between" alignItems="center">
|
||||
<Heading fontSize="xl" py={1}>
|
||||
{t('common.modelManager')}
|
||||
@@ -28,7 +39,7 @@ export const ModelManager = memo(() => {
|
||||
</Button>
|
||||
)}
|
||||
</Flex>
|
||||
<Flex flexDir="column" layerStyle="second" p={4} gap={4} borderRadius="base" w="full" h="full">
|
||||
<Flex flexDir="column" gap={4} w="full" h="full">
|
||||
<ModelListNavigation />
|
||||
<ModelList />
|
||||
</Flex>
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { PiImage } from 'react-icons/pi';
|
||||
@@ -6,19 +7,23 @@ type Props = {
|
||||
image_url?: string | null;
|
||||
};
|
||||
|
||||
export const MODEL_IMAGE_THUMBNAIL_SIZE = '40px';
|
||||
const FALLBACK_ICON_SIZE = '24px';
|
||||
const MODEL_IMAGE_THUMBNAIL_SIZE = '54px';
|
||||
const FALLBACK_ICON_SIZE = '28px';
|
||||
|
||||
const sharedSx: SystemStyleObject = {
|
||||
rounded: 'base',
|
||||
height: MODEL_IMAGE_THUMBNAIL_SIZE,
|
||||
minWidth: MODEL_IMAGE_THUMBNAIL_SIZE,
|
||||
bg: 'base.850',
|
||||
borderWidth: '1px',
|
||||
borderColor: 'base.750',
|
||||
borderStyle: 'solid',
|
||||
};
|
||||
|
||||
const ModelImage = ({ image_url }: Props) => {
|
||||
if (!image_url) {
|
||||
return (
|
||||
<Flex
|
||||
height={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
borderRadius="base"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
<Flex alignItems="center" justifyContent="center" sx={sharedSx}>
|
||||
<Icon color="base.500" as={PiImage} boxSize={FALLBACK_ICON_SIZE} />
|
||||
</Flex>
|
||||
);
|
||||
@@ -29,16 +34,14 @@ const ModelImage = ({ image_url }: Props) => {
|
||||
src={image_url}
|
||||
objectFit="cover"
|
||||
objectPosition="50% 50%"
|
||||
height={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
width={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
minHeight={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
borderRadius="base"
|
||||
sx={sharedSx}
|
||||
fallback={
|
||||
<Flex
|
||||
height={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
borderRadius="base"
|
||||
sx={sharedSx}
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
|
||||
@@ -1,32 +1,57 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { ConfirmationAlertDialog, Flex, IconButton, Spacer, Text, useDisclosure } from '@invoke-ai/ui-library';
|
||||
import { Flex, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectModelManagerV2Slice, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||
import ModelFormatBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { ModelDeleteButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelDeleteButton';
|
||||
import { filesize } from 'filesize';
|
||||
import type { MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import ModelImage, { MODEL_IMAGE_THUMBNAIL_SIZE } from './ModelImage';
|
||||
import ModelImage from './ModelImage';
|
||||
|
||||
type ModelListItemProps = {
|
||||
model: AnyModelConfig;
|
||||
};
|
||||
|
||||
const sx: SystemStyleObject = {
|
||||
_hover: { bg: 'base.700' },
|
||||
"&[aria-selected='true']": { bg: 'base.700' },
|
||||
paddingInline: 3,
|
||||
paddingBlock: 2,
|
||||
position: 'relative',
|
||||
rounded: 'base',
|
||||
'&:after,&:before': {
|
||||
content: `''`,
|
||||
position: 'absolute',
|
||||
pointerEvents: 'none',
|
||||
},
|
||||
'&:after': {
|
||||
h: '1px',
|
||||
bottom: '-0.5px',
|
||||
insetInline: 3,
|
||||
bg: 'base.850',
|
||||
},
|
||||
'&:before': {
|
||||
left: 1,
|
||||
w: 1,
|
||||
insetBlock: 2,
|
||||
rounded: 'base',
|
||||
},
|
||||
_hover: {
|
||||
bg: 'base.850',
|
||||
'& .delete-button': { opacity: 1 },
|
||||
},
|
||||
'& .delete-button': { opacity: 0 },
|
||||
"&[aria-selected='false']:hover:before": { bg: 'base.750' },
|
||||
"&[aria-selected='true']": {
|
||||
bg: 'base.800',
|
||||
'& .delete-button': { opacity: 1 },
|
||||
},
|
||||
"&[aria-selected='true']:before": { bg: 'invokeBlue.300' },
|
||||
};
|
||||
|
||||
const ModelListItem = ({ model }: ModelListItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const selectIsSelected = useMemo(
|
||||
() =>
|
||||
@@ -37,58 +62,25 @@ const ModelListItem = ({ model }: ModelListItemProps) => {
|
||||
[model.key]
|
||||
);
|
||||
const isSelected = useAppSelector(selectIsSelected);
|
||||
const [deleteModel] = useDeleteModelsMutation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const handleSelectModel = useCallback(() => {
|
||||
dispatch(setSelectedModelKey(model.key));
|
||||
}, [model.key, dispatch]);
|
||||
|
||||
const onClickDeleteButton = useCallback(
|
||||
(e: MouseEvent<HTMLButtonElement>) => {
|
||||
e.stopPropagation();
|
||||
onOpen();
|
||||
},
|
||||
[onOpen]
|
||||
);
|
||||
const handleModelDelete = useCallback(() => {
|
||||
deleteModel({ key: model.key })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
toast({
|
||||
id: 'MODEL_DELETED',
|
||||
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
|
||||
status: 'success',
|
||||
});
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
toast({
|
||||
id: 'MODEL_DELETE_FAILED',
|
||||
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
|
||||
status: 'error',
|
||||
});
|
||||
}
|
||||
});
|
||||
dispatch(setSelectedModelKey(null));
|
||||
}, [deleteModel, model, dispatch, t]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={sx}
|
||||
aria-selected={isSelected}
|
||||
justifyContent="flex-start"
|
||||
p={2}
|
||||
borderRadius="base"
|
||||
w="full"
|
||||
alignItems="center"
|
||||
alignItems="flex-start"
|
||||
gap={2}
|
||||
cursor="pointer"
|
||||
onClick={handleSelectModel}
|
||||
>
|
||||
<Flex gap={2} w="full" h="full" minW={0}>
|
||||
<ModelImage image_url={model.cover_image} />
|
||||
<Flex gap={1} alignItems="flex-start" flexDir="column" w="full" minW={0}>
|
||||
<Flex alignItems="flex-start" flexDir="column" w="full" minW={0}>
|
||||
<Flex gap={2} w="full" alignItems="flex-start">
|
||||
<Text fontWeight="semibold" noOfLines={1} wordBreak="break-all">
|
||||
{model.name}
|
||||
@@ -101,39 +93,15 @@ const ModelListItem = ({ model }: ModelListItemProps) => {
|
||||
<Text variant="subtext" noOfLines={1}>
|
||||
{model.description || 'No Description'}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Flex
|
||||
h={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
flexDir="column"
|
||||
alignItems="flex-end"
|
||||
justifyContent="space-between"
|
||||
gap={2}
|
||||
>
|
||||
<ModelBaseBadge base={model.base} />
|
||||
<ModelFormatBadge format={model.format} />
|
||||
<Flex gap={1} mt={1}>
|
||||
<ModelBaseBadge base={model.base} />
|
||||
<ModelFormatBadge format={model.format} />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<IconButton
|
||||
onClick={onClickDeleteButton}
|
||||
icon={<PiTrashSimpleBold size={16} />}
|
||||
aria-label={t('modelManager.deleteConfig')}
|
||||
colorScheme="error"
|
||||
h={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
w={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
/>
|
||||
<ConfirmationAlertDialog
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
title={t('modelManager.deleteModel')}
|
||||
acceptCallback={handleModelDelete}
|
||||
acceptButtonText={t('modelManager.delete')}
|
||||
useInert={false}
|
||||
>
|
||||
<Flex rowGap={4} flexDirection="column">
|
||||
<Text fontWeight="bold">{t('modelManager.deleteMsg1')}</Text>
|
||||
<Text>{t('modelManager.deleteMsg2')}</Text>
|
||||
</Flex>
|
||||
</ConfirmationAlertDialog>
|
||||
<Flex mt={1}>
|
||||
<ModelDeleteButton modelConfig={model} showLabel={false} />
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Flex, IconButton, Input, InputGroup, InputRightElement, Spacer } from '@invoke-ai/ui-library';
|
||||
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSearchTerm, setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { t } from 'i18next';
|
||||
@@ -25,9 +25,7 @@ export const ModelListNavigation = memo(() => {
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="center" justifyContent="space-between">
|
||||
<ModelTypeFilter />
|
||||
<Spacer />
|
||||
<InputGroup maxW="400px">
|
||||
<InputGroup>
|
||||
<Input
|
||||
placeholder={t('modelManager.search')}
|
||||
value={searchTerm || ''}
|
||||
@@ -47,6 +45,9 @@ export const ModelListNavigation = memo(() => {
|
||||
</InputRightElement>
|
||||
)}
|
||||
</InputGroup>
|
||||
<Flex shrink={0}>
|
||||
<ModelTypeFilter />
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { StickyScrollable } from 'features/system/components/StickyScrollable';
|
||||
import { memo } from 'react';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
@@ -9,10 +10,23 @@ type ModelListWrapperProps = {
|
||||
modelList: AnyModelConfig[];
|
||||
};
|
||||
|
||||
const headingSx = {
|
||||
bg: 'base.900',
|
||||
pb: 3,
|
||||
pl: 3,
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
const contentSx = {
|
||||
gap: 0,
|
||||
p: 0,
|
||||
bg: 'base.900',
|
||||
borderRadius: '0',
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
export const ModelListWrapper = memo((props: ModelListWrapperProps) => {
|
||||
const { title, modelList } = props;
|
||||
return (
|
||||
<StickyScrollable title={title} contentSx={{ gap: 1, p: 2 }}>
|
||||
<StickyScrollable title={title} contentSx={contentSx} headingSx={headingSx}>
|
||||
{modelList.map((model) => (
|
||||
<ModelListItem key={model.key} model={model} />
|
||||
))}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
@@ -6,13 +7,23 @@ import { memo } from 'react';
|
||||
import { InstallModels } from './InstallModels';
|
||||
import { Model } from './ModelPanel/Model';
|
||||
|
||||
const modelPaneSx: SystemStyleObject = {
|
||||
layerStyle: 'first',
|
||||
p: 4,
|
||||
borderRadius: 'base',
|
||||
w: {
|
||||
base: '50%',
|
||||
lg: '75%',
|
||||
'2xl': '85%',
|
||||
},
|
||||
h: 'full',
|
||||
minWidth: '300px',
|
||||
overflow: 'auto',
|
||||
};
|
||||
|
||||
export const ModelPane = memo(() => {
|
||||
const selectedModelKey = useAppSelector(selectSelectedModelKey);
|
||||
return (
|
||||
<Box layerStyle="first" p={4} borderRadius="base" w="50%" h="full">
|
||||
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
|
||||
</Box>
|
||||
);
|
||||
return <Box sx={modelPaneSx}>{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}</Box>;
|
||||
});
|
||||
|
||||
ModelPane.displayName = 'ModelPane';
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, IconButton, Image } from '@invoke-ai/ui-library';
|
||||
import { dropzoneAccept } from 'common/hooks/useImageUploadButton';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
@@ -8,6 +9,21 @@ import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold, PiUploadBold } from 'react-icons/pi';
|
||||
import { useDeleteModelImageMutation, useUpdateModelImageMutation } from 'services/api/endpoints/models';
|
||||
|
||||
const sharedSx: SystemStyleObject = {
|
||||
w: 108,
|
||||
h: 108,
|
||||
fontSize: 36,
|
||||
borderRadius: 'base',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
bg: 'base.800',
|
||||
borderWidth: '1px',
|
||||
borderStyle: 'solid',
|
||||
borderColor: 'base.700',
|
||||
flexShrink: 0,
|
||||
};
|
||||
|
||||
type Props = {
|
||||
model_key: string | null;
|
||||
model_image?: string | null;
|
||||
@@ -86,10 +102,9 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
|
||||
src={image}
|
||||
objectFit="cover"
|
||||
objectPosition="50% 50%"
|
||||
height={108}
|
||||
width={108}
|
||||
minWidth={108}
|
||||
borderRadius="base"
|
||||
sx={sharedSx}
|
||||
/>
|
||||
<IconButton
|
||||
position="absolute"
|
||||
@@ -112,10 +127,9 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
|
||||
variant="ghost"
|
||||
aria-label={t('modelManager.uploadImage')}
|
||||
tooltip={t('modelManager.uploadImage')}
|
||||
w={108}
|
||||
h={108}
|
||||
fontSize={36}
|
||||
icon={<PiUploadBold />}
|
||||
sx={sharedSx}
|
||||
isLoading={request.isLoading}
|
||||
{...getRootProps()}
|
||||
/>
|
||||
|
||||
@@ -52,6 +52,7 @@ export const ModelConvertButton = memo(({ modelConfig }: ModelConvertProps) => {
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={onOpen}
|
||||
size="sm"
|
||||
aria-label={t('modelManager.convertToDiffusers')}
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
import { Button, ConfirmationAlertDialog, Flex, IconButton, Text, useDisclosure } from '@invoke-ai/ui-library';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, type MouseEvent, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
showLabel?: boolean;
|
||||
modelConfig: AnyModelConfig;
|
||||
};
|
||||
|
||||
export const ModelDeleteButton = memo(({ showLabel = true, modelConfig }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const log = logger('models');
|
||||
|
||||
const [deleteModel] = useDeleteModelsMutation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const onClickDeleteButton = useCallback(
|
||||
(e: MouseEvent<HTMLButtonElement>) => {
|
||||
e.stopPropagation();
|
||||
onOpen();
|
||||
},
|
||||
[onOpen]
|
||||
);
|
||||
|
||||
const handleModelDelete = useCallback(() => {
|
||||
deleteModel({ key: modelConfig.key })
|
||||
.unwrap()
|
||||
.then(() => {
|
||||
dispatch(setSelectedModelKey(null));
|
||||
toast({
|
||||
id: 'MODEL_DELETED',
|
||||
title: `${t('modelManager.modelDeleted')}: ${modelConfig.name}`,
|
||||
status: 'success',
|
||||
});
|
||||
})
|
||||
.catch((error) => {
|
||||
log.error('Error deleting model', error);
|
||||
toast({
|
||||
id: 'MODEL_DELETE_FAILED',
|
||||
title: `${t('modelManager.modelDeleteFailed')}: ${modelConfig.name}`,
|
||||
status: 'error',
|
||||
});
|
||||
});
|
||||
}, [deleteModel, modelConfig.key, modelConfig.name, dispatch, t, log]);
|
||||
|
||||
return (
|
||||
<>
|
||||
{showLabel ? (
|
||||
<Button
|
||||
className="delete-button"
|
||||
size="sm"
|
||||
leftIcon={<PiTrashSimpleBold />}
|
||||
colorScheme="error"
|
||||
onClick={onClickDeleteButton}
|
||||
flexShrink={0}
|
||||
>
|
||||
{t('modelManager.delete')}
|
||||
</Button>
|
||||
) : (
|
||||
<IconButton
|
||||
className="delete-button"
|
||||
onClick={onClickDeleteButton}
|
||||
icon={<PiTrashSimpleBold size={16} />}
|
||||
aria-label={t('modelManager.deleteConfig')}
|
||||
colorScheme="error"
|
||||
/>
|
||||
)}
|
||||
|
||||
<ConfirmationAlertDialog
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
title={t('modelManager.deleteModel')}
|
||||
acceptCallback={handleModelDelete}
|
||||
acceptButtonText={t('modelManager.delete')}
|
||||
useInert={false}
|
||||
>
|
||||
<Flex rowGap={4} flexDirection="column">
|
||||
<Text fontWeight="bold">{t('modelManager.deleteMsg1')}</Text>
|
||||
<Text>{t('modelManager.deleteMsg2')}</Text>
|
||||
</Flex>
|
||||
</ConfirmationAlertDialog>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
ModelDeleteButton.displayName = 'ModelDeleteButton';
|
||||
@@ -24,6 +24,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
||||
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||
import { ModelFooter } from './ModelFooter';
|
||||
|
||||
type Props = {
|
||||
modelConfig: AnyModelConfig;
|
||||
@@ -158,6 +159,7 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
<ModelFooter modelConfig={modelConfig} isEditing={true} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import { Flex, Heading, type SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import { ModelConvertButton } from './ModelConvertButton';
|
||||
import { ModelDeleteButton } from './ModelDeleteButton';
|
||||
import { ModelEditButton } from './ModelEditButton';
|
||||
|
||||
const footerRowSx: SystemStyleObject = {
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
gap: 3,
|
||||
'&:not(:last-of-type)': {
|
||||
borderBottomWidth: '1px',
|
||||
borderBottomStyle: 'solid',
|
||||
borderBottomColor: 'border',
|
||||
},
|
||||
p: 3,
|
||||
};
|
||||
|
||||
type Props = {
|
||||
modelConfig: AnyModelConfig;
|
||||
isEditing: boolean;
|
||||
};
|
||||
|
||||
export const ModelFooter = memo(({ modelConfig, isEditing }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const shouldShowConvertOption = !isEditing && modelConfig.format === 'checkpoint' && modelConfig.type === 'main';
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" borderWidth="1px" borderRadius="base">
|
||||
{shouldShowConvertOption && (
|
||||
<Flex sx={footerRowSx}>
|
||||
<Heading size="sm" color="base.100">
|
||||
{t('modelManager.convertToDiffusers')}
|
||||
</Heading>
|
||||
<Flex py={1}>
|
||||
<ModelConvertButton modelConfig={modelConfig} />
|
||||
</Flex>
|
||||
</Flex>
|
||||
)}
|
||||
{!isEditing && (
|
||||
<Flex sx={footerRowSx}>
|
||||
<Heading size="sm" color="base.100">
|
||||
{t('modelManager.edit')}
|
||||
</Heading>
|
||||
<Flex py={1}>
|
||||
<ModelEditButton />
|
||||
</Flex>
|
||||
</Flex>
|
||||
)}
|
||||
<Flex sx={footerRowSx}>
|
||||
<Heading size="sm" color="error.200">
|
||||
{t('modelManager.deleteModel')}
|
||||
</Heading>
|
||||
<Flex py={1}>
|
||||
<ModelDeleteButton modelConfig={modelConfig} />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ModelFooter.displayName = 'ModelFooter';
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Box, Flex, SimpleGrid } from '@invoke-ai/ui-library';
|
||||
import { Box, Divider, Flex, SimpleGrid } from '@invoke-ai/ui-library';
|
||||
import { ControlAdapterModelDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/ControlAdapterModelDefaultSettings';
|
||||
import { LoRAModelDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/LoRAModelDefaultSettings/LoRAModelDefaultSettings';
|
||||
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
|
||||
@@ -12,6 +12,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings';
|
||||
import { ModelAttrView } from './ModelAttrView';
|
||||
import { ModelFooter } from './ModelFooter';
|
||||
import { RelatedModels } from './RelatedModels';
|
||||
|
||||
type Props = {
|
||||
@@ -39,15 +40,16 @@ export const ModelView = memo(({ modelConfig }: Props) => {
|
||||
}, [modelConfig.base, modelConfig.type]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<Flex flexDir="column" gap={4} h="full">
|
||||
<ModelHeader modelConfig={modelConfig}>
|
||||
{modelConfig.format === 'checkpoint' && modelConfig.type === 'main' && (
|
||||
<ModelConvertButton modelConfig={modelConfig} />
|
||||
)}
|
||||
<ModelEditButton />
|
||||
</ModelHeader>
|
||||
<Flex flexDir="column" h="full" gap={4}>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<Divider />
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<Box>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={modelConfig.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={modelConfig.type} />
|
||||
@@ -73,26 +75,33 @@ export const ModelView = memo(({ modelConfig }: Props) => {
|
||||
</SimpleGrid>
|
||||
</Box>
|
||||
{withSettings && (
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
|
||||
<MainModelDefaultSettings modelConfig={modelConfig} />
|
||||
)}
|
||||
{(modelConfig.type === 'controlnet' ||
|
||||
modelConfig.type === 't2i_adapter' ||
|
||||
modelConfig.type === 'control_lora') && <ControlAdapterModelDefaultSettings modelConfig={modelConfig} />}
|
||||
{modelConfig.type === 'lora' && (
|
||||
<>
|
||||
<LoRAModelDefaultSettings modelConfig={modelConfig} />
|
||||
<TriggerPhrases modelConfig={modelConfig} />
|
||||
</>
|
||||
)}
|
||||
{modelConfig.type === 'main' && <TriggerPhrases modelConfig={modelConfig} />}
|
||||
</Box>
|
||||
<>
|
||||
<Divider />
|
||||
<Box>
|
||||
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
|
||||
<MainModelDefaultSettings modelConfig={modelConfig} />
|
||||
)}
|
||||
{(modelConfig.type === 'controlnet' ||
|
||||
modelConfig.type === 't2i_adapter' ||
|
||||
modelConfig.type === 'control_lora') && (
|
||||
<ControlAdapterModelDefaultSettings modelConfig={modelConfig} />
|
||||
)}
|
||||
{modelConfig.type === 'lora' && (
|
||||
<>
|
||||
<LoRAModelDefaultSettings modelConfig={modelConfig} />
|
||||
<TriggerPhrases modelConfig={modelConfig} />
|
||||
</>
|
||||
)}
|
||||
{modelConfig.type === 'main' && <TriggerPhrases modelConfig={modelConfig} />}
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
<Box overflowY="auto" layerStyle="second" borderRadius="base" p={4}>
|
||||
<Divider />
|
||||
<Box overflowY="auto">
|
||||
<RelatedModels modelConfig={modelConfig} />
|
||||
</Box>
|
||||
</Flex>
|
||||
<ModelFooter modelConfig={modelConfig} isEditing={false} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -74,6 +74,8 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
[addTriggerPhrase]
|
||||
);
|
||||
|
||||
const hasTriggerPhrases = triggerPhrases.length > 0;
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full" gap="5">
|
||||
<form onSubmit={onTriggerPhraseAddFormSubmit}>
|
||||
@@ -99,14 +101,16 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||
</FormControl>
|
||||
</form>
|
||||
|
||||
<Flex gap="4" flexWrap="wrap">
|
||||
{triggerPhrases.map((phrase, index) => (
|
||||
<Tag size="md" key={index} py={2} px={4} bg="base.700">
|
||||
<TagLabel>{phrase}</TagLabel>
|
||||
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
|
||||
</Tag>
|
||||
))}
|
||||
</Flex>
|
||||
{hasTriggerPhrases && (
|
||||
<Flex gap="4" flexWrap="wrap">
|
||||
{triggerPhrases.map((phrase, index) => (
|
||||
<Tag size="md" key={index} py={2} px={4} bg="base.700">
|
||||
<TagLabel>{phrase}</TagLabel>
|
||||
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
|
||||
</Tag>
|
||||
))}
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
import { FloatFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInput';
|
||||
import { FloatFieldInputAndSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInputAndSlider';
|
||||
import { FloatFieldSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldSlider';
|
||||
import ChatGPT4oModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ChatGPT4oModelFieldInputComponent';
|
||||
import { FloatFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatFieldCollectionInputComponent';
|
||||
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
|
||||
import FluxKontextModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxKontextModelFieldInputComponent';
|
||||
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
|
||||
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
|
||||
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
|
||||
import Imagen4ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen4ModelFieldInputComponent';
|
||||
import { IntegerFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerFieldCollectionInputComponent';
|
||||
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
|
||||
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
|
||||
@@ -27,22 +23,8 @@ import {
|
||||
isBoardFieldInputTemplate,
|
||||
isBooleanFieldInputInstance,
|
||||
isBooleanFieldInputTemplate,
|
||||
isChatGPT4oModelFieldInputInstance,
|
||||
isChatGPT4oModelFieldInputTemplate,
|
||||
isCLIPEmbedModelFieldInputInstance,
|
||||
isCLIPEmbedModelFieldInputTemplate,
|
||||
isCLIPGEmbedModelFieldInputInstance,
|
||||
isCLIPGEmbedModelFieldInputTemplate,
|
||||
isCLIPLEmbedModelFieldInputInstance,
|
||||
isCLIPLEmbedModelFieldInputTemplate,
|
||||
isCogView4MainModelFieldInputInstance,
|
||||
isCogView4MainModelFieldInputTemplate,
|
||||
isColorFieldInputInstance,
|
||||
isColorFieldInputTemplate,
|
||||
isControlLoRAModelFieldInputInstance,
|
||||
isControlLoRAModelFieldInputTemplate,
|
||||
isControlNetModelFieldInputInstance,
|
||||
isControlNetModelFieldInputTemplate,
|
||||
isEnumFieldInputInstance,
|
||||
isEnumFieldInputTemplate,
|
||||
isFloatFieldCollectionInputInstance,
|
||||
@@ -51,68 +33,28 @@ import {
|
||||
isFloatFieldInputTemplate,
|
||||
isFloatGeneratorFieldInputInstance,
|
||||
isFloatGeneratorFieldInputTemplate,
|
||||
isFluxKontextModelFieldInputInstance,
|
||||
isFluxKontextModelFieldInputTemplate,
|
||||
isFluxMainModelFieldInputInstance,
|
||||
isFluxMainModelFieldInputTemplate,
|
||||
isFluxReduxModelFieldInputInstance,
|
||||
isFluxReduxModelFieldInputTemplate,
|
||||
isFluxVAEModelFieldInputInstance,
|
||||
isFluxVAEModelFieldInputTemplate,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isImageFieldInputInstance,
|
||||
isImageFieldInputTemplate,
|
||||
isImageGeneratorFieldInputInstance,
|
||||
isImageGeneratorFieldInputTemplate,
|
||||
isImagen3ModelFieldInputInstance,
|
||||
isImagen3ModelFieldInputTemplate,
|
||||
isImagen4ModelFieldInputInstance,
|
||||
isImagen4ModelFieldInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
isIntegerFieldInputTemplate,
|
||||
isIntegerGeneratorFieldInputInstance,
|
||||
isIntegerGeneratorFieldInputTemplate,
|
||||
isIPAdapterModelFieldInputInstance,
|
||||
isIPAdapterModelFieldInputTemplate,
|
||||
isLLaVAModelFieldInputInstance,
|
||||
isLLaVAModelFieldInputTemplate,
|
||||
isLoRAModelFieldInputInstance,
|
||||
isLoRAModelFieldInputTemplate,
|
||||
isMainModelFieldInputInstance,
|
||||
isMainModelFieldInputTemplate,
|
||||
isModelIdentifierFieldInputInstance,
|
||||
isModelIdentifierFieldInputTemplate,
|
||||
isRunwayModelFieldInputInstance,
|
||||
isRunwayModelFieldInputTemplate,
|
||||
isSchedulerFieldInputInstance,
|
||||
isSchedulerFieldInputTemplate,
|
||||
isSD3MainModelFieldInputInstance,
|
||||
isSD3MainModelFieldInputTemplate,
|
||||
isSDXLMainModelFieldInputInstance,
|
||||
isSDXLMainModelFieldInputTemplate,
|
||||
isSDXLRefinerModelFieldInputInstance,
|
||||
isSDXLRefinerModelFieldInputTemplate,
|
||||
isSigLipModelFieldInputInstance,
|
||||
isSigLipModelFieldInputTemplate,
|
||||
isSpandrelImageToImageModelFieldInputInstance,
|
||||
isSpandrelImageToImageModelFieldInputTemplate,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
isStringFieldInputInstance,
|
||||
isStringFieldInputTemplate,
|
||||
isStringGeneratorFieldInputInstance,
|
||||
isStringGeneratorFieldInputTemplate,
|
||||
isT2IAdapterModelFieldInputInstance,
|
||||
isT2IAdapterModelFieldInputTemplate,
|
||||
isT5EncoderModelFieldInputInstance,
|
||||
isT5EncoderModelFieldInputTemplate,
|
||||
isVAEModelFieldInputInstance,
|
||||
isVAEModelFieldInputTemplate,
|
||||
isVeo3ModelFieldInputInstance,
|
||||
isVeo3ModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { memo } from 'react';
|
||||
@@ -121,33 +63,10 @@ import { assert } from 'tsafe';
|
||||
|
||||
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
||||
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
|
||||
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
|
||||
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
|
||||
import CogView4MainModelFieldInputComponent from './inputs/CogView4MainModelFieldInputComponent';
|
||||
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||
import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldInputComponent';
|
||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
|
||||
import FluxReduxModelFieldInputComponent from './inputs/FluxReduxModelFieldInputComponent';
|
||||
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
|
||||
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
||||
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
|
||||
import LLaVAModelFieldInputComponent from './inputs/LLaVAModelFieldInputComponent';
|
||||
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
|
||||
import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent';
|
||||
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
||||
import RunwayModelFieldInputComponent from './inputs/RunwayModelFieldInputComponent';
|
||||
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
||||
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
|
||||
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
||||
import SigLipModelFieldInputComponent from './inputs/SigLipModelFieldInputComponent';
|
||||
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
|
||||
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
||||
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
|
||||
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
|
||||
import Veo3ModelFieldInputComponent from './inputs/Veo3ModelFieldInputComponent';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
@@ -287,13 +206,6 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <BoardFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isMainModelFieldInputTemplate(template)) {
|
||||
if (!isMainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isModelIdentifierFieldInputTemplate(template)) {
|
||||
if (!isModelIdentifierFieldInputInstance(field)) {
|
||||
return null;
|
||||
@@ -301,159 +213,6 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSDXLRefinerModelFieldInputTemplate(template)) {
|
||||
if (!isSDXLRefinerModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <RefinerModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isVAEModelFieldInputTemplate(template)) {
|
||||
if (!isVAEModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <VAEModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isT5EncoderModelFieldInputTemplate(template)) {
|
||||
if (!isT5EncoderModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
if (isCLIPEmbedModelFieldInputTemplate(template)) {
|
||||
if (!isCLIPEmbedModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isCLIPLEmbedModelFieldInputTemplate(template)) {
|
||||
if (!isCLIPLEmbedModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isCLIPGEmbedModelFieldInputTemplate(template)) {
|
||||
if (!isCLIPGEmbedModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isControlLoRAModelFieldInputTemplate(template)) {
|
||||
if (!isControlLoRAModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isLLaVAModelFieldInputTemplate(template)) {
|
||||
if (!isLLaVAModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <LLaVAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isFluxVAEModelFieldInputTemplate(template)) {
|
||||
if (!isFluxVAEModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isLoRAModelFieldInputTemplate(template)) {
|
||||
if (!isLoRAModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <LoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isControlNetModelFieldInputTemplate(template)) {
|
||||
if (!isControlNetModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <ControlNetModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isIPAdapterModelFieldInputTemplate(template)) {
|
||||
if (!isIPAdapterModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <IPAdapterModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isT2IAdapterModelFieldInputTemplate(template)) {
|
||||
if (!isT2IAdapterModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSpandrelImageToImageModelFieldInputTemplate(template)) {
|
||||
if (!isSpandrelImageToImageModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <SpandrelImageToImageModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSigLipModelFieldInputTemplate(template)) {
|
||||
if (!isSigLipModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <SigLipModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isFluxReduxModelFieldInputTemplate(template)) {
|
||||
if (!isFluxReduxModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <FluxReduxModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isImagen3ModelFieldInputTemplate(template)) {
|
||||
if (!isImagen3ModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <Imagen3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isImagen4ModelFieldInputTemplate(template)) {
|
||||
if (!isImagen4ModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <Imagen4ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isFluxKontextModelFieldInputTemplate(template)) {
|
||||
if (!isFluxKontextModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <FluxKontextModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isChatGPT4oModelFieldInputTemplate(template)) {
|
||||
if (!isChatGPT4oModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <ChatGPT4oModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isVeo3ModelFieldInputTemplate(template)) {
|
||||
if (!isVeo3ModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <Veo3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isRunwayModelFieldInputTemplate(template)) {
|
||||
if (!isRunwayModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <RunwayModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isColorFieldInputTemplate(template)) {
|
||||
if (!isColorFieldInputInstance(field)) {
|
||||
return null;
|
||||
@@ -461,34 +220,6 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <ColorFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isFluxMainModelFieldInputTemplate(template)) {
|
||||
if (!isFluxMainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSD3MainModelFieldInputTemplate(template)) {
|
||||
if (!isSD3MainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isCogView4MainModelFieldInputTemplate(template)) {
|
||||
if (!isCogView4MainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <CogView4MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSDXLMainModelFieldInputTemplate(template)) {
|
||||
if (!isSDXLMainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSchedulerFieldInputTemplate(template)) {
|
||||
if (!isSchedulerFieldInputInstance(field)) {
|
||||
return null;
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { CLIPEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate>;
|
||||
|
||||
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
const onChange = useCallback(
|
||||
(value: CLIPEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldCLIPEmbedValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CLIPEmbedModelFieldInputComponent);
|
||||
@@ -1,45 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldCLIPGEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import { type CLIPGEmbedModelConfig, isCLIPGEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate>;
|
||||
|
||||
const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: CLIPGEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldCLIPGEmbedValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config))}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CLIPGEmbedModelFieldInputComponent);
|
||||
@@ -1,45 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldCLIPLEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import { type CLIPLEmbedModelConfig, isCLIPLEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate>;
|
||||
|
||||
const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: CLIPLEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldCLIPLEmbedValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config))}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CLIPLEmbedModelFieldInputComponent);
|
||||
@@ -1,46 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldChatGPT4oModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useChatGPT4oModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ApiModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const ChatGPT4oModelFieldInputComponent = (
|
||||
props: FieldComponentProps<ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useChatGPT4oModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ApiModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldChatGPT4oModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ChatGPT4oModelFieldInputComponent);
|
||||
@@ -1,63 +0,0 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
CogView4MainModelFieldInputInstance,
|
||||
CogView4MainModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCogView4Models } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CogView4MainModelFieldInputInstance, CogView4MainModelFieldInputTemplate>;
|
||||
|
||||
const CogView4MainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCogView4Models();
|
||||
const _onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl
|
||||
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
|
||||
isDisabled={!options.length}
|
||||
isInvalid={!value && props.fieldTemplate.required}
|
||||
>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CogView4MainModelFieldInputComponent);
|
||||
@@ -1,48 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldControlLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
ControlLoRAModelFieldInputInstance,
|
||||
ControlLoRAModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useControlLoRAModel } from 'services/api/hooks/modelsByType';
|
||||
import type { ControlLoRAModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<ControlLoRAModelFieldInputInstance, ControlLoRAModelFieldInputTemplate>;
|
||||
|
||||
const ControlLoRAModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useControlLoRAModel();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ControlLoRAModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldControlLoRAModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ControlLoRAModelFieldInputComponent);
|
||||
@@ -1,45 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useControlNetModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ControlNetModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate>;
|
||||
|
||||
const ControlNetModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useControlNetModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ControlNetModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldControlNetModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ControlNetModelFieldInputComponent);
|
||||
@@ -1,49 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldFluxKontextModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
FluxKontextModelFieldInputInstance,
|
||||
FluxKontextModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxKontextModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ApiModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const FluxKontextModelFieldInputComponent = (
|
||||
props: FieldComponentProps<FluxKontextModelFieldInputInstance, FluxKontextModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useFluxKontextModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ApiModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldFluxKontextModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxKontextModelFieldInputComponent);
|
||||
@@ -1,44 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate>;
|
||||
|
||||
const FluxMainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useFluxModels();
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxMainModelFieldInputComponent);
|
||||
@@ -1,46 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldFluxReduxModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxReduxModels } from 'services/api/hooks/modelsByType';
|
||||
import type { FLUXReduxModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const FluxReduxModelFieldInputComponent = (
|
||||
props: FieldComponentProps<FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useFluxReduxModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: FLUXReduxModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldFluxReduxModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxReduxModelFieldInputComponent);
|
||||
@@ -1,44 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate>;
|
||||
|
||||
const FluxVAEModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useFluxVAEModels();
|
||||
const onChange = useCallback(
|
||||
(value: VAEModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldFluxVAEModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxVAEModelFieldInputComponent);
|
||||
@@ -1,45 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
import type { IPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const IPAdapterModelFieldInputComponent = (
|
||||
props: FieldComponentProps<IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useIPAdapterModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: IPAdapterModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldIPAdapterModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IPAdapterModelFieldInputComponent);
|
||||
@@ -1,46 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldImagen3ModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useImagen3Models } from 'services/api/hooks/modelsByType';
|
||||
import type { ApiModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const Imagen3ModelFieldInputComponent = (
|
||||
props: FieldComponentProps<Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useImagen3Models();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ApiModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldImagen3ModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(Imagen3ModelFieldInputComponent);
|
||||
@@ -1,46 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldImagen4ModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useImagen4Models } from 'services/api/hooks/modelsByType';
|
||||
import type { ApiModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const Imagen4ModelFieldInputComponent = (
|
||||
props: FieldComponentProps<Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useImagen4Models();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ApiModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldImagen4ModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(Imagen4ModelFieldInputComponent);
|
||||
@@ -1,44 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldLLaVAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useLLaVAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LlavaOnevisionConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate>;
|
||||
|
||||
const LLaVAModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useLLaVAModels();
|
||||
const onChange = useCallback(
|
||||
(value: LlavaOnevisionConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldLLaVAModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LLaVAModelFieldInputComponent);
|
||||
@@ -1,44 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate>;
|
||||
|
||||
const LoRAModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useLoRAModels();
|
||||
const onChange = useCallback(
|
||||
(value: LoRAModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldLoRAModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LoRAModelFieldInputComponent);
|
||||
@@ -1,44 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<MainModelFieldInputInstance, MainModelFieldInputTemplate>;
|
||||
|
||||
const MainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(MainModelFieldInputComponent);
|
||||
@@ -12,7 +12,7 @@ import type { FieldComponentProps } from './types';
|
||||
type Props = FieldComponentProps<ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate>;
|
||||
|
||||
const ModelIdentifierFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetModelConfigsQuery();
|
||||
const onChange = useCallback(
|
||||
@@ -36,8 +36,31 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return modelConfigsAdapterSelectors.selectAll(data);
|
||||
}, [data]);
|
||||
if (!fieldTemplate.ui_model_base && !fieldTemplate.ui_model_type) {
|
||||
return modelConfigsAdapterSelectors.selectAll(data);
|
||||
}
|
||||
|
||||
return modelConfigsAdapterSelectors.selectAll(data).filter((config) => {
|
||||
if (fieldTemplate.ui_model_base && !fieldTemplate.ui_model_base.includes(config.base)) {
|
||||
return false;
|
||||
}
|
||||
if (fieldTemplate.ui_model_type && !fieldTemplate.ui_model_type.includes(config.type)) {
|
||||
return false;
|
||||
}
|
||||
if (
|
||||
fieldTemplate.ui_model_variant &&
|
||||
'variant' in config &&
|
||||
config.variant &&
|
||||
!fieldTemplate.ui_model_variant.includes(config.variant)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
if (fieldTemplate.ui_model_format && !fieldTemplate.ui_model_format.includes(config.format)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}, [data, fieldTemplate]);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
SDXLRefinerModelFieldInputInstance,
|
||||
SDXLRefinerModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useRefinerModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<SDXLRefinerModelFieldInputInstance, SDXLRefinerModelFieldInputTemplate>;
|
||||
|
||||
const RefinerModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useRefinerModels();
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldRefinerModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(RefinerModelFieldInputComponent);
|
||||
@@ -1,46 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldRunwayModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { RunwayModelFieldInputInstance, RunwayModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useRunwayModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VideoApiModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const RunwayModelFieldInputComponent = (
|
||||
props: FieldComponentProps<RunwayModelFieldInputInstance, RunwayModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useRunwayModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: VideoApiModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldRunwayModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(RunwayModelFieldInputComponent);
|
||||
@@ -1,44 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSD3Models } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
|
||||
|
||||
const SD3MainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useSD3Models();
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SD3MainModelFieldInputComponent);
|
||||
@@ -1,44 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSDXLModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate>;
|
||||
|
||||
const SDXLMainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useSDXLModels();
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SDXLMainModelFieldInputComponent);
|
||||
@@ -1,46 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldSigLipModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSigLipModels } from 'services/api/hooks/modelsByType';
|
||||
import type { SigLipModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const SigLipModelFieldInputComponent = (
|
||||
props: FieldComponentProps<SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useSigLipModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: SigLipModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldSigLipModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SigLipModelFieldInputComponent);
|
||||
@@ -1,49 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldSpandrelImageToImageModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
SpandrelImageToImageModelFieldInputInstance,
|
||||
SpandrelImageToImageModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
|
||||
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const SpandrelImageToImageModelFieldInputComponent = (
|
||||
props: FieldComponentProps<SpandrelImageToImageModelFieldInputInstance, SpandrelImageToImageModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: SpandrelImageToImageModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldSpandrelImageToImageModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SpandrelImageToImageModelFieldInputComponent);
|
||||
@@ -1,46 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
import type { T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const T2IAdapterModelFieldInputComponent = (
|
||||
props: FieldComponentProps<T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useT2IAdapterModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: T2IAdapterModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldT2IAdapterModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(T2IAdapterModelFieldInputComponent);
|
||||
@@ -1,43 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldT5EncoderValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
|
||||
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate>;
|
||||
|
||||
const T5EncoderModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useT5EncoderModels();
|
||||
const onChange = useCallback(
|
||||
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldT5EncoderValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(T5EncoderModelFieldInputComponent);
|
||||
@@ -1,44 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<VAEModelFieldInputInstance, VAEModelFieldInputTemplate>;
|
||||
|
||||
const VAEModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useVAEModels();
|
||||
const onChange = useCallback(
|
||||
(value: VAEModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldVaeModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(VAEModelFieldInputComponent);
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user