Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-01-18 11:57:55 -05:00.
Compare commits — 133 commits:

143487a492, 203fa04295, 954fce3c67, 821889148a, 4c248d8c2c, deb75805d4, 93110654da, ff0c48d532, de18073814, 0708af9545,
1e85184c62, 11d3b8d944, bffd4afb96, 518a896521, 2647ff141a, ba0bac2aa5, 862e2a3e49, d22fd32b05, 391e5b7f8c, c9d2a5f59a,
1f63b60021, a499b9f54e, 104505ea02, ee4002607c, fd20582cdd, 43b0d07517, f83592a052, b3ee906749, 5d69e9068a, a79136b058,
944af4d4a9, 5e001be73a, 576a644b3a, 703557c8a6, d59a53b3f9, 7b8f78c2d9, 31ab9be79a, 5011fab85d, 92bdb9fdcc, 548e766c0b,
ff897f74a1, 3d29c996ed, 42d57d1225, 193fa9395a, 56cd839d5b, 7b446ee40d, 17027c4070, 13d44f47ce, 550fbdeb1c, a01cd7c497,
c54afd600c, 4f911a0ea8, fb91f48722, 69db60a614, c6d7f951aa, 04c005284c, 2d7f9697bf, ae530492a2, 87ed1e3b6d, cc54466db9,
cbdafe7e38, 112cb76174, e56d41ab99, 273dfd86ab, 871271fde5, 14944872c4, 07bcf3c446, 8ed5585285, 5ce226a467, c64f20a72b,
0c9c10a03a, 4a0df6b865, ba165572bf, c3d6a10603, 4efc86299d, e8c7cf63fd, 698b034190, 3988128c40, c768f47365, 19a63abc54,
75ec36bf9a, d802f8e7fb, 6873e0308d, 66eb73088e, ed81a13eb4, fbc1aae52d, ba42c3e63f, b24e820aa0, e8f6b3b77a, 8f13518c97,
6afbc12074, 6b0a56ceb9, ca92497e52, 97d45ceaf2, aeb3841a6f, c14d33d3c1, 676e59e072, e7dcb6a03f, fb95b7cc2b, 015dc3ac0d,
9d8a71b362, 2eb212f393, 34b268c15c, 9a203a64dc, d80004e056, de32ed23a7, 5aed2b315d, 48db6cfc4f, aa7c5c281a, 87aeb7f889,
3b3d6e413a, b6432f2de3, 9d0a28ccae, c3bf0a3277, b516610c1e, 677e717cd7, c52584e057, b6767441db, 8745dbe67d, a565d9473e,
4dbf07c3e0, f6eb4d9a6b, 5037967b82, 4930ba48ce, 40d2092256, d2e9237740, b191b706c1, 4d0f760ec8, 65cda5365a, 1f2d1d086f,
418f3c3f19, 72173e284c, 9cc13556aa
@@ -39,7 +39,7 @@ nodes imported in the `__init__.py` file are loaded. See the README in the nodes
folder for more examples:

```py
from .cool_node import CoolInvocation
from .cool_node import ResizeInvocation
```

## Creating A New Invocation

@@ -69,7 +69,10 @@ The first set of things we need to do when creating a new Invocation are -
So let us do that.

```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.invocation_api import (
    BaseInvocation,
    invocation,
)

@invocation('resize')
class ResizeInvocation(BaseInvocation):

@@ -103,8 +106,12 @@ create your own custom field types later in this guide. For now, let's go ahead
and use it.

```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
from invokeai.app.invocations.primitives import ImageField
from invokeai.invocation_api import (
    BaseInvocation,
    ImageField,
    InputField,
    invocation,
)

@invocation('resize')
class ResizeInvocation(BaseInvocation):

@@ -128,8 +135,12 @@ image: ImageField = InputField(description="The input image")
Great. Now let us create our other inputs for `width` and `height`

```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
from invokeai.app.invocations.primitives import ImageField
from invokeai.invocation_api import (
    BaseInvocation,
    ImageField,
    InputField,
    invocation,
)

@invocation('resize')
class ResizeInvocation(BaseInvocation):

@@ -163,8 +174,13 @@ that are provided by it by InvokeAI.
Let us create this function first.

```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.invocation_api import (
    BaseInvocation,
    ImageField,
    InputField,
    InvocationContext,
    invocation,
)

@invocation('resize')
class ResizeInvocation(BaseInvocation):

@@ -191,8 +207,14 @@ all the necessary info related to image outputs. So let us use that.
We will cover how to create your own output types later in this guide.

```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.invocation_api import (
    BaseInvocation,
    ImageField,
    InputField,
    InvocationContext,
    invocation,
)

from invokeai.app.invocations.image import ImageOutput

@invocation('resize')

@@ -217,9 +239,15 @@ Perfect. Now that we have our Invocation setup, let us do what we want to do.
So let's do that.

```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.image import ImageOutput, ResourceOrigin, ImageCategory
from invokeai.invocation_api import (
    BaseInvocation,
    ImageField,
    InputField,
    InvocationContext,
    invocation,
)

from invokeai.app.invocations.image import ImageOutput

@invocation("resize")
class ResizeInvocation(BaseInvocation):
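The hunks above only touch the guide's import blocks (switching them to `invokeai.invocation_api`), so the finished node is never shown in full. For orientation, here is a minimal sketch of what the completed `ResizeInvocation` from the guide might look like; the `context.images.get_pil` / `context.images.save` calls and `ImageOutput.build` are assumptions about the current invocation API, not taken from this diff:

```python
from PIL import Image

from invokeai.app.invocations.image import ImageOutput
from invokeai.invocation_api import (
    BaseInvocation,
    ImageField,
    InputField,
    InvocationContext,
    invocation,
)


@invocation("resize")
class ResizeInvocation(BaseInvocation):
    """Resizes an image to the given width and height."""

    image: ImageField = InputField(description="The input image")
    width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
    height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Load the input image, resize it, save the result, and return it as an ImageOutput.
        # The exact context methods used here are illustrative, not shown in the hunks above.
        pil_image = context.images.get_pil(self.image.image_name)
        resized = pil_image.resize((self.width, self.height), resample=Image.Resampling.LANCZOS)
        image_dto = context.images.save(image=resized)
        return ImageOutput.build(image_dto)
```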
@@ -893,6 +893,12 @@ class HFTokenHelper:
        huggingface_hub.login(token=token, add_to_git_credential=False)
        return cls.get_status()

    @classmethod
    def reset_token(cls) -> HFTokenStatus:
        with SuppressOutput(), contextlib.suppress(Exception):
            huggingface_hub.logout()
        return cls.get_status()


@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
async def get_hf_login_status() -> HFTokenStatus:

@@ -915,3 +921,8 @@ async def do_hf_login(
        ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")

    return token_status


@model_manager_router.delete("/hf_login", operation_id="reset_hf_token", response_model=HFTokenStatus)
async def reset_hf_token() -> HFTokenStatus:
    return HFTokenHelper.reset_token()
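The new DELETE route mirrors the existing GET/POST pair on `/hf_login`. A quick way to exercise it from a client might look like the sketch below; the host and the `/api/v2/models` router prefix are assumptions (the prefix is not visible in this hunk), only the `/hf_login` suffix and the HTTP verbs come from the diff:

```python
import requests

BASE = "http://localhost:9090/api/v2/models"  # assumed host and router prefix

# Check whether a HuggingFace token is currently configured.
print("before reset:", requests.get(f"{BASE}/hf_login").json())

# Log out of HuggingFace and clear the stored token, then re-check the status.
print("after reset:", requests.delete(f"{BASE}/hf_login").json())
```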
@@ -25,7 +25,7 @@ from typing import (
)

import semver
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined

@@ -72,13 +72,24 @@ class Classification(str, Enum, metaclass=MetaEnum):
    Special = "special"


class Bottleneck(str, Enum, metaclass=MetaEnum):
    """
    The bottleneck of an invocation.
    - `Network`: The invocation's execution is network-bound.
    - `GPU`: The invocation's execution is GPU-bound.
    """

    Network = "network"
    GPU = "gpu"


class UIConfigBase(BaseModel):
    """
    Provides additional node configuration to the UI.
    This is used internally by the @invocation decorator logic. Do not use this directly.
    """

    tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
    tags: Optional[list[str]] = Field(default=None, description="The node's tags")
    title: Optional[str] = Field(default=None, description="The node's display name")
    category: Optional[str] = Field(default=None, description="The node's category")
    version: str = Field(

@@ -100,6 +111,12 @@ class BaseInvocationOutput(BaseModel):
    All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
    """

    output_meta: Optional[dict[str, JsonValue]] = Field(
        default=None,
        description="Optional dictionary of metadata for the invocation output, unrelated to the invocation's actual output value. This is not exposed as an output field.",
        json_schema_extra={"field_kind": FieldKind.NodeAttribute},
    )

    @staticmethod
    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
        """Adds various UI-facing attributes to the invocation output's OpenAPI schema."""

@@ -235,6 +252,8 @@ class BaseInvocation(ABC, BaseModel):
        json_schema_extra={"field_kind": FieldKind.NodeAttribute},
    )

    bottleneck: ClassVar[Bottleneck]

    UIConfig: ClassVar[UIConfigBase]

    model_config = ConfigDict(

@@ -256,6 +275,26 @@ class InvocationRegistry:
    @classmethod
    def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
        """Registers an invocation."""

        invocation_type = invocation.get_type()
        node_pack = invocation.UIConfig.node_pack

        # Log a warning when an existing invocation is being clobbered by the one we are registering
        clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
        if clobbered_invocation is not None:
            # This should always be true - we just checked if the invocation type was in the set
            clobbered_node_pack = clobbered_invocation.UIConfig.node_pack

            if clobbered_node_pack == "invokeai":
                # The invocation being clobbered is a core invocation
                logger.warning(f'Overriding core node "{invocation_type}" with node from "{node_pack}"')
            else:
                # The invocation being clobbered is a custom invocation
                logger.warning(
                    f'Overriding node "{invocation_type}" from "{node_pack}" with node from "{clobbered_node_pack}"'
                )
            cls._invocation_classes.remove(clobbered_invocation)

        cls._invocation_classes.add(invocation)
        cls.invalidate_invocation_typeadapter()

@@ -314,6 +353,15 @@ class InvocationRegistry:
    @classmethod
    def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
        """Registers an invocation output."""
        output_type = output.get_type()

        # Log a warning when an existing invocation is being clobbered by the one we are registering
        clobbered_output = InvocationRegistry.get_output_for_type(output_type)
        if clobbered_output is not None:
            # TODO(psyche): We do not record the node pack of the output, so we cannot log it here
            logger.warning(f'Overriding invocation output "{output_type}"')
            cls._output_classes.remove(clobbered_output)

        cls._output_classes.add(output)
        cls.invalidate_output_typeadapter()

@@ -322,6 +370,11 @@ class InvocationRegistry:
        """Gets all invocation outputs."""
        return cls._output_classes

    @classmethod
    def get_outputs_map(cls) -> dict[str, type[BaseInvocationOutput]]:
        """Gets a map of all output types to their output classes."""
        return {i.get_type(): i for i in cls.get_output_classes()}

    @classmethod
    @lru_cache(maxsize=1)
    def get_output_typeadapter(cls) -> TypeAdapter[Any]:

@@ -347,6 +400,11 @@ class InvocationRegistry:
        """Gets all invocation output types."""
        return (i.get_type() for i in cls.get_output_classes())

    @classmethod
    def get_output_for_type(cls, output_type: str) -> type[BaseInvocationOutput] | None:
        """Gets the output class for a given output type."""
        return cls.get_outputs_map().get(output_type)


RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
    "id",

@@ -354,11 +412,12 @@ RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
    "use_cache",
    "type",
    "workflow",
    "bottleneck",
}

RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"}

RESERVED_OUTPUT_FIELD_NAMES = {"type"}
RESERVED_OUTPUT_FIELD_NAMES = {"type", "output_meta"}


class _Model(BaseModel):

@@ -438,6 +497,7 @@ def invocation(
    version: Optional[str] = None,
    use_cache: Optional[bool] = True,
    classification: Classification = Classification.Stable,
    bottleneck: Bottleneck = Bottleneck.GPU,
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
    """
    Registers an invocation.

@@ -449,6 +509,7 @@ def invocation(
    :param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
    :param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
    :param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
    :param Bottleneck bottleneck: The bottleneck of the invocation. Defaults to Bottleneck.GPU. Use Network if the invocation is network-bound.
    """

    def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
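The new `bottleneck` argument defaults to `Bottleneck.GPU`, so existing nodes are unaffected; only network-bound nodes need to opt in. A hedged sketch of how a custom node might declare itself network-bound — the node body and the `StringOutput` import path are illustrative, not taken from this diff:

```python
from invokeai.app.invocations.baseinvocation import Bottleneck, invocation
from invokeai.app.invocations.primitives import StringOutput
from invokeai.invocation_api import BaseInvocation, InputField, InvocationContext


# A hypothetical node that calls a remote service: it spends its time waiting on the
# network rather than the GPU, so it declares bottleneck=Bottleneck.Network.
@invocation(
    "remote_caption",
    title="Remote Caption (example)",
    version="1.0.0",
    bottleneck=Bottleneck.Network,
)
class RemoteCaptionInvocation(BaseInvocation):
    prompt: str = InputField(description="Prompt sent to the remote service")

    def invoke(self, context: InvocationContext) -> StringOutput:
        # ... call the remote service here ...
        return StringOutput(value=f"caption for: {self.prompt}")
```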
@@ -460,25 +521,6 @@ def invocation(
        # The node pack is the module name - will be "invokeai" for built-in nodes
        node_pack = cls.__module__.split(".")[0]

        # Handle the case where an existing node is being clobbered by the one we are registering
        if invocation_type in InvocationRegistry.get_invocation_types():
            clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
            # This should always be true - we just checked if the invocation type was in the set
            assert clobbered_invocation is not None

            clobbered_node_pack = clobbered_invocation.UIConfig.node_pack

            if clobbered_node_pack == "invokeai":
                # The node being clobbered is a core node
                raise ValueError(
                    f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a core node with the same type already exists'
                )
            else:
                # The node being clobbered is a custom node
                raise ValueError(
                    f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a node with the same type already exists in node pack "{clobbered_node_pack}"'
                )

        validate_fields(cls.model_fields, invocation_type)

        # Add OpenAPI schema extras

@@ -504,6 +546,8 @@ def invocation(
        if use_cache is not None:
            cls.model_fields["use_cache"].default = use_cache

        cls.bottleneck = bottleneck

        # Add the invocation type to the model.

        # You'd be tempted to just add the type field and rebuild the model, like this:

@@ -572,13 +616,9 @@ def invocation_output(
        if re.compile(r"^\S+$").match(output_type) is None:
            raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')

        if output_type in InvocationRegistry.get_output_types():
            raise ValueError(f'Invocation type "{output_type}" already exists')

        validate_fields(cls.model_fields, output_type)

        # Add the output type to the model.

        output_type_annotation = Literal[output_type]  # type: ignore
        output_type_field = Field(
            title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
@@ -61,6 +61,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
    SigLipModel = "SigLipModelField"
    FluxReduxModel = "FluxReduxModelField"
    LlavaOnevisionModel = "LLaVAModelField"
    Imagen3Model = "Imagen3ModelField"
    ChatGPT4oModel = "ChatGPT4oModelField"
    # endregion

    # region Misc Field Types

@@ -241,6 +241,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
    batch_status: BatchStatus = Field(description="The status of the batch")
    queue_status: SessionQueueStatus = Field(description="The status of the queue")
    session_id: str = Field(description="The ID of the session (aka graph execution state)")
    credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")

    @classmethod
    def build(

@@ -263,6 +264,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
            completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
            batch_status=batch_status,
            queue_status=queue_status,
            credits=queue_item.credits,
        )

@@ -257,6 +257,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
    api_output_fields: Optional[list[FieldIdentifier]] = Field(
        default=None, description="The nodes that were used as output from the API"
    )
    credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")

    @classmethod
    def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":

@@ -61,6 +61,10 @@ def get_openapi_func(
    # We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
    for output in InvocationRegistry.get_output_classes():
        json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
        # Remove output_metadata that is only used on back-end from the schema
        if "output_meta" in json_schema["properties"]:
            json_schema["properties"].pop("output_meta")

        move_defs_to_top_level(openapi_schema, json_schema)
        openapi_schema["components"]["schemas"][output.__name__] = json_schema
@@ -10,7 +10,7 @@ def get_timestamp() -> int:


def get_iso_timestamp() -> str:
    return datetime.datetime.utcnow().isoformat()
    return datetime.datetime.now(datetime.timezone.utc).isoformat()


def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
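`datetime.utcnow()` returns a naive datetime (and is deprecated as of Python 3.12); the replacement produces an ISO string that carries an explicit UTC offset. A small illustration of the difference (the printed values are examples):

```python
import datetime

# Naive: no timezone attached, e.g. '2025-01-01T12:00:00.000000'
naive = datetime.datetime.utcnow().isoformat()

# Aware: carries the UTC offset, e.g. '2025-01-01T12:00:00.000000+00:00'
aware = datetime.datetime.now(datetime.timezone.utc).isoformat()

print(naive)
print(aware)
```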
@@ -144,6 +144,7 @@ class ModelConfigBase(ABC, BaseModel):
    submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
        description="Loadable submodels in this model", default=None
    )
    usage_info: Optional[str] = Field(default=None, description="Usage information for this model")

    _USING_LEGACY_PROBE: ClassVar[set] = set()
    _USING_CLASSIFY_API: ClassVar[set] = set()

@@ -600,6 +601,21 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
    }


class ApiModelConfig(MainConfigBase, ModelConfigBase):
    """Model config for API-based models."""

    format: Literal[ModelFormat.Api] = ModelFormat.Api

    @classmethod
    def matches(cls, mod: ModelOnDisk) -> bool:
        # API models are not stored on disk, so we can't match them.
        return False

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        raise NotImplementedError("API models are not parsed from disk.")


def get_model_discriminator_value(v: Any) -> str:
    """
    Computes the discriminator value for a model config.

@@ -667,6 +683,7 @@ AnyModelConfig = Annotated[
        Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
        Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
        Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
        Annotated[ApiModelConfig, ApiModelConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]
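`AnyModelConfig` relies on pydantic's callable-discriminator support: `get_model_discriminator_value` maps a raw dict or instance to a tag, and each `Annotated[..., XxxConfig.get_tag()]` member supplies the matching `Tag`. A minimal, generic sketch of the same pattern — the model names and fields here are invented for illustration, not InvokeAI's:

```python
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class DiskModel(BaseModel):
    format: Literal["checkpoint"] = "checkpoint"
    path: str


class ApiModel(BaseModel):
    format: Literal["api"] = "api"
    provider: str


def get_discriminator_value(v: Any) -> str:
    # Works for both raw dicts (validation input) and model instances (serialization).
    if isinstance(v, dict):
        return v["format"]
    return v.format


AnyModel = Annotated[
    Union[
        Annotated[DiskModel, Tag("checkpoint")],
        Annotated[ApiModel, Tag("api")],
    ],
    Discriminator(get_discriminator_value),
]

adapter = TypeAdapter(AnyModel)
parsed = adapter.validate_python({"format": "api", "provider": "chatgpt-4o"})
print(type(parsed).__name__)  # ApiModel
```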
@@ -13,6 +13,12 @@ from invokeai.backend.patches.layers.lora_layer import LoRALayer


def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
    """An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
    # up matrix and down matrix have different ranks so we can't simply multiply them
    if lora_layer.up.shape[1] != lora_layer.down.shape[0]:
        x = torch.nn.functional.linear(input, lora_layer.get_weight(lora_weight), bias=lora_layer.bias)
        x *= lora_weight * lora_layer.scale()
        return x

    x = torch.nn.functional.linear(input, lora_layer.down)
    if lora_layer.mid is not None:
        x = torch.nn.functional.linear(x, lora_layer.mid)
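For equal-rank layers, the two computation orders used in `linear_lora_forward` are algebraically the same: applying `down` and then `up` to the activations equals applying the fused `up @ down` weight once. A quick numerical check of that equivalence (generic LoRA math, not InvokeAI-specific code):

```python
import torch

batch, d_in, d_out, rank = 2, 16, 8, 4
x = torch.randn(batch, d_in)
down = torch.randn(rank, d_in)   # maps d_in -> rank
up = torch.randn(d_out, rank)    # maps rank -> d_out

# Sidecar order: project down to the low-rank space, then back up.
sidecar = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up)

# Fused order: multiply the small matrices first, then apply the combined weight once.
fused = torch.nn.functional.linear(x, up @ down)

print(torch.allclose(sidecar, fused, atol=1e-4))  # True
```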
@@ -26,7 +26,8 @@ class BaseModelType(str, Enum):
    StableDiffusionXLRefiner = "sdxl-refiner"
    Flux = "flux"
    CogView4 = "cogview4"
    # Kandinsky2_1 = "kandinsky-2.1"
    Imagen3 = "imagen3"
    ChatGPT4o = "chatgpt-4o"


class ModelType(str, Enum):

@@ -98,6 +99,7 @@ class ModelFormat(str, Enum):
    BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
    BnbQuantizednf4b = "bnb_quantized_nf4b"
    GGUFQuantized = "gguf_quantized"
    Api = "api"


class SchedulerPredictionType(str, Enum):
@@ -19,6 +19,7 @@ class LoRALayer(LoRALayerBase):
        self.up = up
        self.mid = mid
        self.down = down
        self.are_ranks_equal = up.shape[1] == down.shape[0]

    @classmethod
    def from_state_dict_values(

@@ -58,12 +59,42 @@ class LoRALayer(LoRALayerBase):
    def _rank(self) -> int:
        return self.down.shape[0]

    def fuse_weights(self, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
        """
        Fuse the weights of the up and down matrices of a LoRA layer with different ranks.

        Since the Huggingface implementation of KQV projections are fused, when we convert to Kohya format
        the LoRA weights have different ranks. This function handles the fusion of these differently sized
        matrices.
        """

        fused_lora = torch.zeros((up.shape[0], down.shape[1]), device=down.device, dtype=down.dtype)
        rank_diff = down.shape[0] / up.shape[1]

        if rank_diff > 1:
            rank_diff = down.shape[0] / up.shape[1]
            w_down = down.chunk(int(rank_diff), dim=0)
            for w_down_chunk in w_down:
                fused_lora = fused_lora + (torch.mm(up, w_down_chunk))
        else:
            rank_diff = up.shape[1] / down.shape[0]
            w_up = up.chunk(int(rank_diff), dim=0)
            for w_up_chunk in w_up:
                fused_lora = fused_lora + (torch.mm(w_up_chunk, down))

        return fused_lora

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        if self.mid is not None:
            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
            weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
        else:
            # up matrix and down matrix have different ranks so we can't simply multiply them
            if not self.are_ranks_equal:
                weight = self.fuse_weights(self.up, self.down)
                return weight

            weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)

        return weight
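`fuse_weights` handles the rank mismatch by chunking the larger matrix into rank-sized pieces and summing the partial products, which still yields a single `(out_features, in_features)` weight. A toy example of the `rank_diff > 1` branch with the kind of shapes that arise when fused Q/K/V projections are converted to Kohya format (generic tensor math mirroring the function above, not importing it):

```python
import torch

# Typical mismatch: `down` carries 3x the rank of `up`.
up = torch.randn(96, 4)     # (out_features, rank)
down = torch.randn(12, 32)  # (3 * rank, in_features)

rank_diff = down.shape[0] / up.shape[1]  # 3.0

# Chunk the taller matrix into rank-sized pieces and accumulate the partial products.
fused = torch.zeros((up.shape[0], down.shape[1]), dtype=down.dtype)
for w_down_chunk in down.chunk(int(rank_diff), dim=0):
    fused += up @ w_down_chunk

print(fused.shape)  # torch.Size([96, 32]) -- a single (out_features, in_features) weight
```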
@@ -20,6 +20,14 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
    r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)

# A regex pattern that matches all of the last layer keys in the Kohya FLUX LoRA format.
# Example keys:
#   lora_unet_final_layer_linear.alpha
#   lora_unet_final_layer_linear.lora_down.weight
#   lora_unet_final_layer_linear.lora_up.weight
FLUX_KOHYA_LAST_LAYER_KEY_REGEX = r"lora_unet_final_layer_(linear|linear1|linear2)_?(.*)"

# A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format.
# Example keys:
#   lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha

@@ -44,6 +52,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool
    """
    return all(
        re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
        or re.match(FLUX_KOHYA_LAST_LAYER_KEY_REGEX, k)
        or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
        or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
        for k in state_dict.keys()

@@ -65,6 +74,9 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
    t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
    for layer_name, layer_state_dict in grouped_state_dict.items():
        if layer_name.startswith("lora_unet"):
            # Skip the final layer. This is incompatible with current model definition.
            if layer_name.startswith("lora_unet_final_layer"):
                continue
            transformer_grouped_sd[layer_name] = layer_state_dict
        elif layer_name.startswith("lora_te1"):
            clip_grouped_sd[layer_name] = layer_state_dict
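The new `FLUX_KOHYA_LAST_LAYER_KEY_REGEX` lets the format detector accept `lora_unet_final_layer_*` keys, which the loader above then skips. A quick check of the pattern against the example keys listed in the comments:

```python
import re

FLUX_KOHYA_LAST_LAYER_KEY_REGEX = r"lora_unet_final_layer_(linear|linear1|linear2)_?(.*)"

example_keys = [
    "lora_unet_final_layer_linear.alpha",
    "lora_unet_final_layer_linear.lora_down.weight",
    "lora_unet_final_layer_linear.lora_up.weight",
]

for key in example_keys:
    match = re.match(FLUX_KOHYA_LAST_LAYER_KEY_REGEX, key)
    # Each key matches, with the layer name and the remainder captured as groups.
    print(key, "->", bool(match), match.groups() if match else None)
```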
@@ -52,68 +52,68 @@
    }
  },
  "dependencies": {
    "@atlaskit/pragmatic-drag-and-drop": "^1.4.0",
    "@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^1.4.0",
    "@atlaskit/pragmatic-drag-and-drop": "^1.5.3",
    "@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^2.1.0",
    "@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.0.3",
    "@dagrejs/dagre": "^1.1.4",
    "@dagrejs/graphlib": "^2.2.4",
    "@fontsource-variable/inter": "^5.1.0",
    "@fontsource-variable/inter": "^5.2.5",
    "@invoke-ai/ui-library": "^0.0.46",
    "@nanostores/react": "^0.7.3",
    "@reduxjs/toolkit": "2.6.1",
    "@nanostores/react": "^1.0.0",
    "@reduxjs/toolkit": "2.7.0",
    "@roarr/browser-log-writer": "^1.3.0",
    "@xyflow/react": "^12.5.3",
    "@xyflow/react": "^12.6.0",
    "async-mutex": "^0.5.0",
    "chakra-react-select": "^4.9.2",
    "cmdk": "^1.0.0",
    "cmdk": "^1.1.1",
    "compare-versions": "^6.1.1",
    "filesize": "^10.1.6",
    "fracturedjsonjs": "^4.0.2",
    "framer-motion": "^11.10.0",
    "i18next": "^23.15.1",
    "i18next-http-backend": "^2.6.1",
    "i18next": "^25.0.1",
    "i18next-http-backend": "^3.0.2",
    "idb-keyval": "^6.2.1",
    "jsondiffpatch": "^0.6.0",
    "konva": "^9.3.15",
    "jsondiffpatch": "^0.7.3",
    "konva": "^9.3.20",
    "linkify-react": "^4.2.0",
    "linkifyjs": "^4.2.0",
    "lodash-es": "^4.17.21",
    "lru-cache": "^11.0.1",
    "lru-cache": "^11.1.0",
    "mtwist": "^1.0.2",
    "nanoid": "^5.0.7",
    "nanostores": "^0.11.3",
    "new-github-issue-url": "^1.0.0",
    "overlayscrollbars": "^2.10.0",
    "nanoid": "^5.1.5",
    "nanostores": "^1.0.1",
    "new-github-issue-url": "^1.1.0",
    "overlayscrollbars": "^2.11.1",
    "overlayscrollbars-react": "^0.5.6",
    "perfect-freehand": "^1.2.2",
    "query-string": "^9.1.0",
    "query-string": "^9.1.1",
    "raf-throttle": "^2.0.6",
    "react": "^18.3.1",
    "react-colorful": "^5.6.1",
    "react-dom": "^18.3.1",
    "react-dropzone": "^14.2.9",
    "react-error-boundary": "^4.0.13",
    "react-hook-form": "^7.53.0",
    "react-dropzone": "^14.3.8",
    "react-error-boundary": "^5.0.0",
    "react-hook-form": "^7.56.1",
    "react-hotkeys-hook": "4.5.0",
    "react-i18next": "^15.0.2",
    "react-icons": "^5.3.0",
    "react-redux": "9.1.2",
    "react-resizable-panels": "^2.1.4",
    "react-textarea-autosize": "^8.5.7",
    "react-use": "^17.5.1",
    "react-virtuoso": "^4.12.5",
    "react-i18next": "^15.5.1",
    "react-icons": "^5.5.0",
    "react-redux": "9.2.0",
    "react-resizable-panels": "^2.1.8",
    "react-textarea-autosize": "^8.5.9",
    "react-use": "^17.6.0",
    "react-virtuoso": "^4.12.6",
    "redux-dynamic-middlewares": "^2.2.0",
    "redux-remember": "^5.1.0",
    "redux-remember": "^5.2.0",
    "redux-undo": "^1.1.0",
    "rfdc": "^1.4.1",
    "roarr": "^7.21.1",
    "serialize-error": "^11.0.3",
    "socket.io-client": "^4.8.0",
    "stable-hash": "^0.0.4",
    "use-debounce": "^10.0.3",
    "serialize-error": "^12.0.0",
    "socket.io-client": "^4.8.1",
    "stable-hash": "^0.0.5",
    "use-debounce": "^10.0.4",
    "use-device-pixel-ratio": "^1.1.2",
    "uuid": "^10.0.0",
    "zod": "^3.23.8",
    "uuid": "^11.1.0",
    "zod": "^3.24.3",
    "zod-validation-error": "^3.4.0"
  },
  "peerDependencies": {

@@ -123,43 +123,43 @@
  "devDependencies": {
    "@invoke-ai/eslint-config-react": "^0.0.14",
    "@invoke-ai/prettier-config-react": "^0.0.7",
    "@storybook/addon-essentials": "^8.3.4",
    "@storybook/addon-interactions": "^8.3.4",
    "@storybook/addon-links": "^8.3.4",
    "@storybook/addon-storysource": "^8.3.4",
    "@storybook/manager-api": "^8.3.4",
    "@storybook/react": "^8.3.4",
    "@storybook/react-vite": "^8.5.5",
    "@storybook/theming": "^8.3.4",
    "@storybook/addon-essentials": "^8.6.12",
    "@storybook/addon-interactions": "^8.6.12",
    "@storybook/addon-links": "^8.6.12",
    "@storybook/addon-storysource": "^8.6.12",
    "@storybook/manager-api": "^8.6.12",
    "@storybook/react": "^8.6.12",
    "@storybook/react-vite": "^8.6.12",
    "@storybook/theming": "^8.6.12",
    "@types/lodash-es": "^4.17.12",
    "@types/node": "^20.16.10",
    "@types/node": "^22.15.1",
    "@types/react": "^18.3.11",
    "@types/react-dom": "^18.3.0",
    "@types/uuid": "^10.0.0",
    "@vitejs/plugin-react-swc": "^3.8.0",
    "@vitest/coverage-v8": "^3.0.6",
    "@vitest/ui": "^3.0.6",
    "concurrently": "^8.2.2",
    "@vitejs/plugin-react-swc": "^3.9.0",
    "@vitest/coverage-v8": "^3.1.2",
    "@vitest/ui": "^3.1.2",
    "concurrently": "^9.1.2",
    "csstype": "^3.1.3",
    "dpdm": "^3.14.0",
    "eslint": "^8.57.1",
    "eslint-plugin-i18next": "^6.1.0",
    "eslint-plugin-i18next": "^6.1.1",
    "eslint-plugin-path": "^1.3.0",
    "knip": "^5.31.0",
    "knip": "^5.50.5",
    "openapi-types": "^12.1.3",
    "openapi-typescript": "^7.4.1",
    "prettier": "^3.3.3",
    "rollup-plugin-visualizer": "^5.12.0",
    "storybook": "^8.3.4",
    "openapi-typescript": "^7.6.1",
    "prettier": "^3.5.3",
    "rollup-plugin-visualizer": "^5.14.0",
    "storybook": "^8.6.12",
    "tsafe": "^1.8.5",
    "type-fest": "^4.26.1",
    "typescript": "^5.6.2",
    "vite": "^6.1.0",
    "type-fest": "^4.40.0",
    "typescript": "^5.8.3",
    "vite": "^6.3.3",
    "vite-plugin-css-injected-by-js": "^3.5.2",
    "vite-plugin-dts": "^4.5.0",
    "vite-plugin-dts": "^4.5.3",
    "vite-plugin-eslint": "^1.8.1",
    "vite-tsconfig-paths": "^5.1.4",
    "vitest": "^3.0.6"
    "vitest": "^3.1.2"
  },
  "engines": {
    "pnpm": "8"
invokeai/frontend/web/pnpm-lock.yaml (generated, 3713 lines changed) — diff suppressed because it is too large.
@@ -118,6 +118,8 @@
    "error": "Error",
    "error_withCount_one": "{{count}} error",
    "error_withCount_other": "{{count}} errors",
    "model_withCount_one": "{{count}} model",
    "model_withCount_other": "{{count}} models",
    "file": "File",
    "folder": "Folder",
    "format": "format",

@@ -138,6 +140,8 @@
    "localSystem": "Local System",
    "learnMore": "Learn More",
    "modelManager": "Model Manager",
    "noMatches": "No matches",
    "noOptions": "No options",
    "nodes": "Workflows",
    "notInstalled": "Not $t(common.installed)",
    "openInNewTab": "Open in New Tab",

@@ -171,6 +175,8 @@
    "blue": "Blue",
    "alpha": "Alpha",
    "selected": "Selected",
    "search": "Search",
    "clear": "Clear",
    "tab": "Tab",
    "view": "View",
    "edit": "Edit",

@@ -197,7 +203,11 @@
    "column": "Column",
    "value": "Value",
    "label": "Label",
    "systemInformation": "System Information"
    "systemInformation": "System Information",
    "compactView": "Compact View",
    "fullView": "Full View",
    "options_withCount_one": "{{count}} option",
    "options_withCount_other": "{{count}} options"
  },
  "hrf": {
    "hrf": "High Resolution Fix",

@@ -258,6 +268,7 @@
    "status": "Status",
    "total": "Total",
    "time": "Time",
    "credits": "Credits",
    "pending": "Pending",
    "in_progress": "In Progress",
    "completed": "Completed",

@@ -768,6 +779,7 @@
    "description": "Description",
    "edit": "Edit",
    "fileSize": "File Size",
    "filterModels": "Filter models",
    "fluxRedux": "FLUX Redux",
    "height": "Height",
    "huggingFace": "HuggingFace",

@@ -787,6 +799,7 @@
    "hfTokenUnableToVerify": "Unable to Verify HF Token",
    "hfTokenUnableToVerifyErrorMessage": "Unable to verify HuggingFace token. This is likely due to a network error. Please try again later.",
    "hfTokenSaved": "HF Token Saved",
    "hfTokenReset": "HF Token Reset",
    "urlUnauthorizedErrorMessage": "You may need to configure an API token to access this model.",
    "urlUnauthorizedErrorMessage2": "Learn how here.",
    "imageEncoderModelId": "Image Encoder Model ID",

@@ -821,10 +834,12 @@
    "modelUpdated": "Model Updated",
    "modelUpdateFailed": "Model Update Failed",
    "name": "Name",
    "noModelsInstalled": "No Models Installed",
    "modelPickerFallbackNoModelsInstalled": "No models installed.",
    "modelPickerFallbackNoModelsInstalled2": "Visit the <LinkComponent>Model Manager</LinkComponent> to install models.",
    "noModelsInstalledDesc1": "Install models with the",
    "noModelSelected": "No Model Selected",
    "noMatchingModels": "No matching Models",
    "noMatchingModels": "No matching models",
    "noModelsInstalled": "No models installed",
    "none": "none",
    "path": "Path",
    "pathToConfig": "Path To Config",

@@ -871,7 +886,8 @@
    "installingXModels_one": "Installing {{count}} model",
    "installingXModels_other": "Installing {{count}} models",
    "skippingXDuplicates_one": ", skipping {{count}} duplicate",
    "skippingXDuplicates_other": ", skipping {{count}} duplicates"
    "skippingXDuplicates_other": ", skipping {{count}} duplicates",
    "manageModels": "Manage Models"
  },
  "models": {
    "addLora": "Add LoRA",

@@ -1093,6 +1109,7 @@
    "info": "Info",
    "invoke": {
      "addingImagesTo": "Adding images to",
      "modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your account settings to upgrade.",
      "invoke": "Invoke",
      "missingFieldTemplate": "Missing field template",
      "missingInputForField": "missing input",

@@ -1173,7 +1190,8 @@
    "width": "Width",
    "gaussianBlur": "Gaussian Blur",
    "boxBlur": "Box Blur",
    "staged": "Staged"
    "staged": "Staged",
    "modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your <LinkComponent>account settings</LinkComponent> to upgrade."
  },
  "dynamicPrompts": {
    "showDynamicPrompts": "Show Dynamic Prompts",

@@ -1312,6 +1330,8 @@
    "unableToCopyDesc": "Your browser does not support clipboard access. Firefox users may be able to fix this by following ",
    "unableToCopyDesc_theseSteps": "these steps",
    "fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
    "imagen3IncompatibleGenerationMode": "Google Imagen3 supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
    "chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
    "problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
    "problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
    "workflowUnpublished": "Workflow Unpublished"
@@ -5,6 +5,7 @@ import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $customNavComponent } from 'app/store/nanostores/customNavComponent';

@@ -12,10 +13,13 @@ import type { CustomStarUi } from 'app/store/nanostores/customStarUI';
import { $customStarUI } from 'app/store/nanostores/customStarUI';
import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { $logo } from 'app/store/nanostores/logo';
import { $onClickGoToModelManager } from 'app/store/nanostores/onClickGoToModelManager';
import { $openAPISchemaUrl } from 'app/store/nanostores/openAPISchemaUrl';
import { $projectId, $projectName, $projectUrl } from 'app/store/nanostores/projectId';
import { $queueId, DEFAULT_QUEUE_ID } from 'app/store/nanostores/queueId';
import { $store } from 'app/store/nanostores/store';
import { $toastMap } from 'app/store/nanostores/toastMap';
import { $whatsNew } from 'app/store/nanostores/whatsNew';
import { createStore } from 'app/store/store';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';

@@ -29,6 +33,7 @@ import {
  DEFAULT_WORKFLOW_LIBRARY_TAG_CATEGORIES,
} from 'features/nodes/store/workflowLibrarySlice';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { ToastConfig } from 'features/toast/toast';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
import { Provider } from 'react-redux';

@@ -45,6 +50,7 @@ interface Props extends PropsWithChildren {
  token?: string;
  config?: PartialAppConfig;
  customNavComponent?: ReactNode;
  accountSettingsLink?: string;
  middleware?: Middleware[];
  projectId?: string;
  projectName?: string;

@@ -55,10 +61,16 @@ interface Props extends PropsWithChildren {
  socketOptions?: Partial<ManagerOptions & SocketOptions>;
  isDebugging?: boolean;
  logo?: ReactNode;
  toastMap?: Record<string, ToastConfig>;
  whatsNew?: ReactNode[];
  workflowCategories?: WorkflowCategory[];
  workflowTagCategories?: WorkflowTagCategory[];
  workflowSortOptions?: WorkflowSortOption[];
  loggingOverrides?: LoggingOverrides;
  /**
   * If provided, overrides in-app navigation to the model manager
   */
  onClickGoToModelManager?: () => void;
}

const InvokeAIUI = ({

@@ -67,6 +79,7 @@ const InvokeAIUI = ({
  token,
  config,
  customNavComponent,
  accountSettingsLink,
  middleware,
  projectId,
  projectName,

@@ -77,10 +90,13 @@ const InvokeAIUI = ({
  socketOptions,
  isDebugging = false,
  logo,
  toastMap,
  workflowCategories,
  workflowTagCategories,
  workflowSortOptions,
  loggingOverrides,
  onClickGoToModelManager,
  whatsNew,
}: Props) => {
  useLayoutEffect(() => {
    /*

@@ -169,6 +185,16 @@ const InvokeAIUI = ({
    };
  }, [customNavComponent]);

  useEffect(() => {
    if (accountSettingsLink) {
      $accountSettingsLink.set(accountSettingsLink);
    }

    return () => {
      $accountSettingsLink.set(undefined);
    };
  }, [accountSettingsLink]);

  useEffect(() => {
    if (openAPISchemaUrl) {
      $openAPISchemaUrl.set(openAPISchemaUrl);

@@ -205,6 +231,36 @@ const InvokeAIUI = ({
    };
  }, [logo]);

  useEffect(() => {
    if (toastMap) {
      $toastMap.set(toastMap);
    }

    return () => {
      $toastMap.set(undefined);
    };
  }, [toastMap]);

  useEffect(() => {
    if (whatsNew) {
      $whatsNew.set(whatsNew);
    }

    return () => {
      $whatsNew.set(undefined);
    };
  }, [whatsNew]);

  useEffect(() => {
    if (onClickGoToModelManager) {
      $onClickGoToModelManager.set(onClickGoToModelManager);
    }

    return () => {
      $onClickGoToModelManager.set(undefined);
    };
  }, [onClickGoToModelManager]);

  useEffect(() => {
    if (workflowCategories) {
      $workflowLibraryCategoriesOptions.set(workflowCategories);
@@ -1,3 +1,4 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';

@@ -6,11 +7,14 @@ import { withResult, withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';

@@ -48,32 +52,50 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
          return await buildFLUXGraph(state, manager);
        case 'cogview4':
          return await buildCogView4Graph(state, manager);
        case 'imagen3':
          return await buildImagen3Graph(state, manager);
        case 'chatgpt-4o':
          return await buildChatGPT4oGraph(state, manager);
        default:
          assert(false, `No graph builders for base ${base}`);
      }
    });

    if (buildGraphResult.isErr()) {
      let title = 'Failed to build graph';
      let status: AlertStatus = 'error';
      let description: string | null = null;
      if (buildGraphResult.error instanceof AssertionError) {
        description = extractMessageFromAssertionError(buildGraphResult.error);
      } else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
        title = 'Unsupported generation mode';
        description = buildGraphResult.error.message;
        status = 'warning';
      }
      const error = serializeError(buildGraphResult.error);
      log.error({ error }, 'Failed to build graph');
      toast({
        status: 'error',
        title: 'Failed to build graph',
        status,
        title,
        description,
      });
      return;
    }

    const { g, noise, posCond } = buildGraphResult.value;
    const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;

    const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';

    const prepareBatchResult = withResult(() =>
      prepareLinearUIBatch(state, g, prepend, noise, posCond, 'canvas', destination)
      prepareLinearUIBatch({
        state,
        g,
        prepend,
        seedFieldIdentifier,
        positivePromptFieldIdentifier,
        origin: 'canvas',
        destination,
      })
    );

    if (prepareBatchResult.isErr()) {

@@ -89,7 +111,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
      await req.unwrap();
      log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch');
    } catch (error) {
      log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
      log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
    } finally {
      req.reset();
    }
@@ -18,16 +18,24 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
      const state = getState();
      const { prepend } = action.payload;

      const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
      const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state);

      const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond, 'upscaling', 'gallery');
      const batchConfig = prepareLinearUIBatch({
        state,
        g,
        prepend,
        seedFieldIdentifier,
        positivePromptFieldIdentifier,
        origin: 'upscaling',
        destination: 'gallery',
      });

      const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
      try {
        await req.unwrap();
        log.debug(parseify({ batchConfig }), 'Enqueued batch');
      } catch (error) {
        log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
        log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
      } finally {
        req.reset();
      }
@@ -0,0 +1,3 @@
import { atom } from 'nanostores';

export const $accountSettingsLink = atom<string | undefined>(undefined);

@@ -0,0 +1,3 @@
import { atom } from 'nanostores';

export const $onClickGoToModelManager = atom<(() => void) | undefined>(undefined);

@@ -0,0 +1,4 @@
import type { ToastConfig } from 'features/toast/toast';
import { atom } from 'nanostores';

export const $toastMap = atom<Record<string, ToastConfig> | undefined>(undefined);

@@ -0,0 +1,4 @@
import { atom } from 'nanostores';
import type { ReactNode } from 'react';

export const $whatsNew = atom<ReactNode[] | undefined>(undefined);
@@ -145,7 +145,10 @@ const unserialize: UnserializeFunction = (data, key) => {
    );
    return transformed;
  } catch (err) {
    log.warn({ error: serializeError(err) }, `Error rehydrating slice "${key}", falling back to default initial state`);
    log.warn(
      { error: serializeError(err as Error) },
      `Error rehydrating slice "${key}", falling back to default initial state`
    );
    return persistConfig.initialState;
  }
};
@@ -28,7 +28,8 @@ export type AppFeature =
  | 'starterModels'
  | 'hfToken'
  | 'retryQueueItem'
  | 'cancelAndClearAll';
  | 'cancelAndClearAll'
  | 'chatGPT4oModels';
/**
 * A disable-able Stable Diffusion feature
 */

@@ -83,6 +84,7 @@ export type AppConfig = {
  metadataFetchDebounce?: number;
  workflowFetchDebounce?: number;
  isLocal?: boolean;
  shouldShowCredits: boolean;
  sd: {
    defaultModel?: string;
    disabledControlNetModels: string[];
invokeai/frontend/web/src/common/components/Picker/Picker.tsx (new file, 1092 lines) — diff suppressed because it is too large.
@@ -38,7 +38,7 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
  }, [optionsFilter, getIsDisabled, modelConfigs, shouldShowModelDescriptions]);

  const value = useMemo(
    () => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
    () => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)) ?? null,
    [options, selectedModel]
  );
@@ -1,6 +1,10 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { memo } from 'react';

/**
 * A typed version of React.memo, useful for components that take generics.
 */
export const typedMemo: <T>(c: T) => T = memo;
export const typedMemo: <T extends keyof JSX.IntrinsicElements | React.JSXElementConstructor<any>>(
  component: T,
  propsAreEqual?: (prevProps: React.ComponentProps<T>, nextProps: React.ComponentProps<T>) => boolean
) => T & { displayName?: string } = memo;
@@ -24,6 +24,7 @@ export const CanvasAddEntityButtons = memo(() => {
  const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
  const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
  const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
  const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');

  return (
    <Flex w="full" h="full" justifyContent="center" gap={4}>

@@ -52,6 +53,7 @@ export const CanvasAddEntityButtons = memo(() => {
            justifyContent="flex-start"
            leftIcon={<PiPlusBold />}
            onClick={addInpaintMask}
            isDisabled={!isInpaintLayerEnabled}
          >
            {t('controlLayers.inpaintMask')}
          </Button>
@@ -25,6 +25,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
  const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
  const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
  const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
  const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');

  return (
    <Menu>

@@ -46,7 +47,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
        </MenuItem>
      </MenuGroup>
      <MenuGroup title={t('controlLayers.regional')}>
        <MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
        <MenuItem icon={<PiPlusBold />} onClick={addInpaintMask} isDisabled={!isInpaintLayerEnabled}>
          {t('controlLayers.inpaintMask')}
        </MenuItem>
        <MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={!isRegionalGuidanceEnabled}>
@@ -0,0 +1,63 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ApiModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';

type Props = {
  modelKey: string | null;
  onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => void;
};

export const GlobalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
  const { t } = useTranslation();
  const currentBaseModel = useAppSelector(selectBase);
  const [modelConfigs, { isLoading }] = useGlobalReferenceImageModels();
  const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);

  const _onChangeModel = useCallback(
    (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null) => {
      if (!modelConfig) {
        return;
      }
      onChangeModel(modelConfig);
    },
    [onChangeModel]
  );

  const getIsDisabled = useCallback(
    (model: AnyModelConfig): boolean => {
      const hasMainModel = Boolean(currentBaseModel);
      const hasSameBase = currentBaseModel === model.base;
      return !hasMainModel || !hasSameBase;
    },
    [currentBaseModel]
  );

  const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
    modelConfigs,
    onChange: _onChangeModel,
    selectedModel,
    getIsDisabled,
    isLoading,
  });

  return (
    <Tooltip label={selectedModel?.description}>
      <FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
        <Combobox
          options={options}
          placeholder={t('common.placeholderSelectAModel')}
          value={value}
          onChange={onChange}
          noOptionsMessage={noOptionsMessage}
        />
      </FormControl>
    </Tooltip>
  );
});

GlobalReferenceImageModel.displayName = 'GlobalReferenceImageModel';
@@ -61,7 +61,7 @@ export const IPAdapterImagePreview = memo(
      )}
      {imageDTO && (
        <>
          <DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" />
          <DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" w="full" />
          <Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
            <DndImageIcon
              onClick={handleResetControlImage}
@@ -6,6 +6,7 @@ import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/c
import { Weight } from 'features/controlLayers/components/common/Weight';
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
import { GlobalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/GlobalReferenceImageModel';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';

@@ -33,10 +34,9 @@ import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';

import { IPAdapterImagePreview } from './IPAdapterImagePreview';
import { IPAdapterModel } from './IPAdapterModel';

const buildSelectIPAdapter = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) =>
  createSelector(

@@ -80,7 +80,7 @@ const IPAdapterSettingsContent = memo(() => {
  );

  const onChangeModel = useCallback(
    (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
    (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => {
      dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
    },
    [dispatch, entityIdentifier]

@@ -113,11 +113,7 @@ const IPAdapterSettingsContent = memo(() => {
    <CanvasEntitySettingsWrapper>
      <Flex flexDir="column" gap={2} position="relative" w="full">
        <Flex gap={2} alignItems="center" w="full">
          <IPAdapterModel
            isRegionalGuidance={false}
            modelKey={ipAdapter.model?.key ?? null}
            onChangeModel={onChangeModel}
          />
          <GlobalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
          {ipAdapter.type === 'ip_adapter' && (
            <CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
          )}
@@ -4,29 +4,26 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIPAdapterOrFLUXReduxModels } from 'services/api/hooks/modelsByType';
|
||||
import { useRegionalReferenceImageModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
isRegionalGuidance: boolean;
|
||||
modelKey: string | null;
|
||||
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => void;
|
||||
};
|
||||
|
||||
export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeModel }: Props) => {
|
||||
const filter = (config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
|
||||
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
|
||||
if (config.base === 'flux' && config.type === 'ip_adapter') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
export const RegionalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const currentBaseModel = useAppSelector(selectBase);
|
||||
const filter = useCallback(
|
||||
(config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
|
||||
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
|
||||
if (isRegionalGuidance && config.base === 'flux' && config.type === 'ip_adapter') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[isRegionalGuidance]
|
||||
);
|
||||
const [modelConfigs, { isLoading }] = useIPAdapterOrFLUXReduxModels(filter);
|
||||
const [modelConfigs, { isLoading }] = useRegionalReferenceImageModels(filter);
|
||||
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
|
||||
|
||||
const _onChangeModel = useCallback(
|
||||
@@ -71,4 +68,4 @@ export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeMode
|
||||
);
|
||||
});
|
||||
|
||||
IPAdapterModel.displayName = 'IPAdapterModel';
|
||||
RegionalReferenceImageModel.displayName = 'RegionalReferenceImageModel';
|
||||
@@ -7,7 +7,7 @@ import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLI
|
||||
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
|
||||
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
|
||||
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
|
||||
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
|
||||
import { RegionalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/RegionalReferenceImageModel';
|
||||
import { RegionalGuidanceIPAdapterSettingsEmptyState } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { usePullBboxIntoRegionalGuidanceReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
|
||||
@@ -140,11 +140,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
|
||||
</Flex>
|
||||
<Flex flexDir="column" gap={2} position="relative" w="full">
|
||||
<Flex gap={2} alignItems="center" w="full">
|
||||
<IPAdapterModel
|
||||
isRegionalGuidance={true}
|
||||
modelKey={ipAdapter.model?.key ?? null}
|
||||
onChangeModel={onChangeModel}
|
||||
/>
|
||||
<RegionalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
|
||||
{ipAdapter.type === 'ip_adapter' && (
|
||||
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
|
||||
)}
|
||||
|
||||
@@ -17,16 +17,26 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
|
||||
import type {
|
||||
CanvasEntityIdentifier,
|
||||
CanvasReferenceImageState,
|
||||
CanvasRegionalGuidanceState,
|
||||
ControlLoRAConfig,
|
||||
ControlNetConfig,
|
||||
IPAdapterConfig,
|
||||
T2IAdapterConfig,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/util';
|
||||
import {
|
||||
initialChatGPT4oReferenceImage,
|
||||
initialControlNet,
|
||||
initialIPAdapter,
|
||||
initialT2IAdapter,
|
||||
} from 'features/controlLayers/store/util';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { useCallback } from 'react';
|
||||
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import {
|
||||
modelConfigsAdapterSelectors,
|
||||
selectMainModelConfig,
|
||||
selectModelConfigsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import type {
|
||||
ControlLoRAModelConfig,
|
||||
ControlNetModelConfig,
|
||||
@@ -64,6 +74,35 @@ export const selectDefaultControlAdapter = createSelector(
|
||||
}
|
||||
);
|
||||
|
||||
export const selectDefaultRefImageConfig = createSelector(
|
||||
selectMainModelConfig,
|
||||
selectModelConfigsQuery,
|
||||
selectBase,
|
||||
(selectedMainModel, query, base): CanvasReferenceImageState['ipAdapter'] => {
|
||||
if (selectedMainModel?.base === 'chatgpt-4o') {
|
||||
const referenceImage = deepClone(initialChatGPT4oReferenceImage);
|
||||
referenceImage.model = zModelIdentifierField.parse(selectedMainModel);
|
||||
return referenceImage;
|
||||
}
|
||||
|
||||
const { data } = query;
|
||||
let model: IPAdapterModelConfig | null = null;
|
||||
if (data) {
|
||||
const modelConfigs = modelConfigsAdapterSelectors.selectAll(data).filter(isIPAdapterModelConfig);
|
||||
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
|
||||
model = compatibleModels[0] ?? modelConfigs[0] ?? null;
|
||||
}
|
||||
const ipAdapter = deepClone(initialIPAdapter);
|
||||
if (model) {
|
||||
ipAdapter.model = zModelIdentifierField.parse(model);
|
||||
if (model.base === 'flux') {
|
||||
ipAdapter.clipVisionModel = 'ViT-L';
|
||||
}
|
||||
}
|
||||
return ipAdapter;
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Selects the default IP adapter configuration based on the model configurations and the base.
|
||||
*
|
||||
@@ -146,11 +185,11 @@ export const useAddRegionalReferenceImage = () => {
|
||||
|
||||
export const useAddGlobalReferenceImage = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
|
||||
const defaultRefImage = useAppSelector(selectDefaultRefImageConfig);
|
||||
const func = useCallback(() => {
|
||||
const overrides = { ipAdapter: deepClone(defaultIPAdapter) };
|
||||
const overrides = { ipAdapter: deepClone(defaultRefImage) };
|
||||
dispatch(referenceImageAdded({ isSelected: true, overrides }));
|
||||
}, [defaultIPAdapter, dispatch]);
|
||||
}, [defaultRefImage, dispatch]);
|
||||
|
||||
return func;
|
||||
};
|
||||
|
||||
@@ -41,7 +41,7 @@ export const useCopyLayerToClipboard = () => {
|
||||
});
|
||||
});
|
||||
} catch (error) {
|
||||
log.error({ error: serializeError(error) }, 'Problem copying layer to clipboard');
|
||||
log.error({ error: serializeError(error as Error) }, 'Problem copying layer to clipboard');
|
||||
toast({
|
||||
status: 'error',
|
||||
title: t('toast.problemCopyingLayer'),
|
||||
@@ -82,7 +82,7 @@ export const useCopyCanvasToClipboard = (region: 'canvas' | 'bbox') => {
|
||||
toast({ title: t('controlLayers.regionCopiedToClipboard', { region: startCase(region) }) });
|
||||
});
|
||||
} catch (error) {
|
||||
log.error({ error: serializeError(error) }, 'Failed to save canvas to gallery');
|
||||
log.error({ error: serializeError(error as Error) }, 'Failed to save canvas to gallery');
|
||||
toast({ title: t('controlLayers.copyRegionError', { region: startCase(region) }), status: 'error' });
|
||||
}
|
||||
}, [canvasManager.compositor, canvasManager.stateApi, clipboard, region, t]);
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHook
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { selectDefaultIPAdapter, selectDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import {
|
||||
controlLayerAdded,
|
||||
@@ -198,7 +198,7 @@ export const useNewRegionalReferenceImageFromBbox = () => {
|
||||
export const useNewGlobalReferenceImageFromBbox = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
|
||||
const defaultIPAdapter = useAppSelector(selectDefaultRefImageConfig);
|
||||
|
||||
const arg = useMemo<UseSaveCanvasArg>(() => {
|
||||
const onSave = (imageDTO: ImageDTO) => {
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectIsCogView4, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
|
||||
import {
|
||||
selectIsChatGTP4o,
|
||||
selectIsCogView4,
|
||||
selectIsImagen3,
|
||||
selectIsSD3,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import type { CanvasEntityType } from 'features/controlLayers/store/types';
|
||||
import { useMemo } from 'react';
|
||||
import type { Equals } from 'tsafe';
|
||||
@@ -8,23 +13,25 @@ import { assert } from 'tsafe';
|
||||
export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
|
||||
const isSD3 = useAppSelector(selectIsSD3);
|
||||
const isCogView4 = useAppSelector(selectIsCogView4);
|
||||
const isImagen3 = useAppSelector(selectIsImagen3);
|
||||
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
|
||||
|
||||
const isEntityTypeEnabled = useMemo<boolean>(() => {
|
||||
switch (entityType) {
|
||||
case 'reference_image':
|
||||
return !isSD3 && !isCogView4;
|
||||
return !isSD3 && !isCogView4 && !isImagen3;
|
||||
case 'regional_guidance':
|
||||
return !isSD3 && !isCogView4;
|
||||
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
|
||||
case 'control_layer':
|
||||
return !isSD3 && !isCogView4;
|
||||
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
|
||||
case 'inpaint_mask':
|
||||
return true;
|
||||
return !isImagen3 && !isChatGPT4o;
|
||||
case 'raster_layer':
|
||||
return true;
|
||||
return !isImagen3 && !isChatGPT4o;
|
||||
default:
|
||||
assert<Equals<typeof entityType, never>>(false);
|
||||
}
|
||||
}, [entityType, isSD3, isCogView4]);
|
||||
}, [entityType, isSD3, isCogView4, isImagen3, isChatGPT4o]);
|
||||
|
||||
return isEntityTypeEnabled;
|
||||
};
|
||||
|
||||
@@ -41,7 +41,7 @@ export const useSaveLayerToAssets = () => {
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
});
|
||||
} catch (error) {
|
||||
log.error({ error: serializeError(error) }, 'Problem copying layer to clipboard');
|
||||
log.error({ error: serializeError(error as Error) }, 'Problem copying layer to clipboard');
|
||||
toast({
|
||||
status: 'error',
|
||||
title: t('toast.problemSavingLayer'),
|
||||
|
||||
@@ -519,7 +519,7 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
|
||||
this.manager.cache.imageNameCache.set(hash, imageDTO.image_name);
|
||||
return imageDTO;
|
||||
} catch (error) {
|
||||
this.log.error({ rasterizeArgs, error: serializeError(error) }, 'Failed to rasterize entity');
|
||||
this.log.error({ rasterizeArgs, error: serializeError(error as Error) }, 'Failed to rasterize entity');
|
||||
throw error;
|
||||
} finally {
|
||||
this.manager.stateApi.$rasterizingAdapter.set(null);
|
||||
|
||||
@@ -346,7 +346,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
|
||||
|
||||
// If the user is not holding shift, the transform is retaining aspect ratio. It's not possible to snap to the grid
|
||||
// in this case, because that would change the aspect ratio. So, we only snap to the grid when shift is held.
|
||||
const gridSize = this.manager.stateApi.$shiftKey.get() ? this.manager.stateApi.getGridSize() : 1;
|
||||
const gridSize = this.manager.stateApi.$shiftKey.get() ? this.manager.stateApi.getPositionGridSize() : 1;
|
||||
|
||||
// We need to snap the anchor to the selected grid size, but the positions provided to this callback are absolute,
|
||||
// scaled coordinates. They need to be converted to stage coordinates, snapped, then converted back to absolute
|
||||
@@ -464,7 +464,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
|
||||
return;
|
||||
}
|
||||
const { rect } = this.manager.stateApi.getBbox();
|
||||
const gridSize = this.manager.stateApi.getGridSize();
|
||||
const gridSize = this.manager.stateApi.getPositionGridSize();
|
||||
const width = this.konva.proxyRect.width();
|
||||
const height = this.konva.proxyRect.height();
|
||||
const scaleX = rect.width / width;
|
||||
@@ -498,7 +498,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
|
||||
return;
|
||||
}
|
||||
const { rect } = this.manager.stateApi.getBbox();
|
||||
const gridSize = this.manager.stateApi.getGridSize();
|
||||
const gridSize = this.manager.stateApi.getPositionGridSize();
|
||||
const width = this.konva.proxyRect.width();
|
||||
const height = this.konva.proxyRect.height();
|
||||
const scaleX = rect.width / width;
|
||||
@@ -523,7 +523,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
|
||||
|
||||
onDragMove = () => {
|
||||
// Snap the interaction rect to the grid
|
||||
const gridSize = this.manager.stateApi.getGridSize();
|
||||
const gridSize = this.manager.stateApi.getPositionGridSize();
|
||||
this.konva.proxyRect.x(roundToMultiple(this.konva.proxyRect.x(), gridSize));
|
||||
this.konva.proxyRect.y(roundToMultiple(this.konva.proxyRect.y(), gridSize));
|
||||
|
||||
|
||||
@@ -112,7 +112,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
|
||||
return;
|
||||
}
|
||||
|
||||
const imageElementResult = await withResultAsync(() => loadImage(imageDTO.image_url));
|
||||
const imageElementResult = await withResultAsync(() => loadImage(imageDTO.image_url, true));
|
||||
if (imageElementResult.isErr()) {
|
||||
// Image loading failed (e.g. the URL to the "physical" image is invalid)
|
||||
this.onFailedToLoadImage(t('controlLayers.unableToLoadImage', 'Unable to load image'));
|
||||
|
||||
@@ -493,7 +493,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
|
||||
* Gets the _positional_ grid size for the current canvas. Note that this is not the same as bbox grid size, which is
|
||||
* based on the currently-selected model.
|
||||
*/
|
||||
getGridSize = (): number => {
|
||||
getPositionGridSize = (): number => {
|
||||
const snapToGrid = this.getSettings().snapToGrid;
|
||||
if (!snapToGrid) {
|
||||
return 1;
|
||||
|
||||
@@ -4,8 +4,10 @@ import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'
|
||||
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
|
||||
import { fitRectToGrid, getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectModel } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectBbox } from 'features/controlLayers/store/selectors';
|
||||
import type { Coordinate, Rect } from 'features/controlLayers/store/types';
|
||||
import type { Coordinate, Rect, Tool } from 'features/controlLayers/store/types';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import Konva from 'konva';
|
||||
import { noop } from 'lodash-es';
|
||||
import { atom } from 'nanostores';
|
||||
@@ -178,6 +180,9 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
|
||||
// Listen for the bbox overlay setting to update the overlay's visibility
|
||||
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectBboxOverlay, this.render));
|
||||
|
||||
// Listen for the model changing - some model types constraint the bbox to a certain size or aspect ratio.
|
||||
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectModel, this.render));
|
||||
|
||||
// Update on busy state changes
|
||||
this.subscriptions.add(this.manager.$isBusy.listen(this.render));
|
||||
}
|
||||
@@ -218,12 +223,25 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
|
||||
|
||||
this.syncOverlay();
|
||||
|
||||
const model = this.manager.stateApi.runSelector(selectModel);
|
||||
|
||||
this.konva.transformer.setAttrs({
|
||||
listening: tool === 'bbox',
|
||||
enabledAnchors: tool === 'bbox' ? ALL_ANCHORS : NO_ANCHORS,
|
||||
enabledAnchors: this.getEnabledAnchors(tool, model),
|
||||
});
|
||||
};
|
||||
|
||||
getEnabledAnchors = (tool: Tool, model?: ModelIdentifierField | null): string[] => {
|
||||
if (tool !== 'bbox') {
|
||||
return NO_ANCHORS;
|
||||
}
|
||||
if (model?.base === 'imagen3' || model?.base === 'chatgpt-4o') {
|
||||
// The bbox is not resizable in these modes
|
||||
return NO_ANCHORS;
|
||||
}
|
||||
return ALL_ANCHORS;
|
||||
};
|
||||
|
||||
syncOverlay = () => {
|
||||
const bboxOverlay = this.manager.stateApi.getSettings().bboxOverlay;
|
||||
|
||||
@@ -251,7 +269,7 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
|
||||
onDragMove = () => {
|
||||
// The grid size here is the _position_ grid size, not the _dimension_ grid size - it is not constratined by the
|
||||
// currently-selected model.
|
||||
const gridSize = this.manager.stateApi.getGridSize();
|
||||
const gridSize = this.manager.stateApi.getPositionGridSize();
|
||||
const bbox = this.manager.stateApi.getBbox();
|
||||
const bboxRect: Rect = {
|
||||
...bbox.rect,
|
||||
|
||||
@@ -476,15 +476,24 @@ export function getImageDataTransparency(imageData: ImageData): Transparency {
|
||||
/**
|
||||
* Loads an image from a URL and returns a promise that resolves with the loaded image element.
|
||||
* @param src The image source URL
|
||||
* @param fetchUrlFirst Whether to fetch the image's URL first, assuming the provided `src` will redirect to a different URL. This addresses an issue where CORS headers are dropped during a redirect.
|
||||
* @returns A promise that resolves with the loaded image element
|
||||
*/
|
||||
export function loadImage(src: string): Promise<HTMLImageElement> {
|
||||
export async function loadImage(src: string, fetchUrlFirst?: boolean): Promise<HTMLImageElement> {
|
||||
const authToken = $authToken.get();
|
||||
let url = src;
|
||||
if (authToken && fetchUrlFirst) {
|
||||
const response = await fetch(`${src}?url_only=true`, { credentials: 'include' });
|
||||
const data = await response.json();
|
||||
url = data.url;
|
||||
}
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
const imageElement = new Image();
|
||||
imageElement.onload = () => resolve(imageElement);
|
||||
imageElement.onerror = (error) => reject(error);
|
||||
imageElement.crossOrigin = $authToken.get() ? 'use-credentials' : 'anonymous';
|
||||
imageElement.src = src;
|
||||
imageElement.src = url;
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -10,12 +10,12 @@ export type Extents = {
|
||||
|
||||
/**
|
||||
* Get the bounding box of an image.
|
||||
* @param buffer The ArrayBuffer of the image to get the bounding box of.
|
||||
* @param buffer The ArrayBufferLike of the image to get the bounding box of.
|
||||
* @param width The width of the image.
|
||||
* @param height The height of the image.
|
||||
* @returns The minimum and maximum x and y values of the image's bounding box, or null if the image has no pixels.
|
||||
*/
|
||||
const getImageDataBboxArrayBuffer = (buffer: ArrayBuffer, width: number, height: number): Extents | null => {
|
||||
const getImageDataBboxArrayBufferLike = (buffer: ArrayBufferLike, width: number, height: number): Extents | null => {
|
||||
let minX = width;
|
||||
let minY = height;
|
||||
let maxX = -1;
|
||||
@@ -50,7 +50,7 @@ const getImageDataBboxArrayBuffer = (buffer: ArrayBuffer, width: number, height:
|
||||
|
||||
export type GetBboxTask = {
|
||||
type: 'get_bbox';
|
||||
data: { id: string; buffer: ArrayBuffer; width: number; height: number };
|
||||
data: { id: string; buffer: ArrayBufferLike; width: number; height: number };
|
||||
};
|
||||
|
||||
type TaskWithTimestamps<T extends Record<string, unknown>> = T & { started: number | null; finished: number | null };
|
||||
@@ -95,7 +95,7 @@ function processNextTask() {
|
||||
// Process the task
|
||||
if (task.type === 'get_bbox') {
|
||||
const { buffer, width, height, id } = task.data;
|
||||
const extents = getImageDataBboxArrayBuffer(buffer, width, height);
|
||||
const extents = getImageDataBboxArrayBufferLike(buffer, width, height);
|
||||
const result: ExtentsResult = {
|
||||
type: 'extents',
|
||||
data: { id, extents },
|
||||
|
||||
@@ -34,9 +34,10 @@ import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/com
|
||||
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
|
||||
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { merge } from 'lodash-es';
|
||||
import { isEqual, merge } from 'lodash-es';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import type {
|
||||
ApiModelConfig,
|
||||
ControlLoRAModelConfig,
|
||||
ControlNetModelConfig,
|
||||
FLUXReduxModelConfig,
|
||||
@@ -67,7 +68,7 @@ import type {
|
||||
IPMethodV2,
|
||||
T2IAdapterConfig,
|
||||
} from './types';
|
||||
import { getEntityIdentifier, isRenderableEntity } from './types';
|
||||
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagen3AspectRatioID, isRenderableEntity } from './types';
|
||||
import {
|
||||
converters,
|
||||
getControlLayerState,
|
||||
@@ -76,6 +77,7 @@ import {
|
||||
getReferenceImageState,
|
||||
getRegionalGuidanceState,
|
||||
imageDTOToImageWithDims,
|
||||
initialChatGPT4oReferenceImage,
|
||||
initialControlLoRA,
|
||||
initialControlNet,
|
||||
initialFLUXRedux,
|
||||
@@ -644,7 +646,10 @@ export const canvasSlice = createSlice({
|
||||
referenceImageIPAdapterModelChanged: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
EntityIdentifierPayload<{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null }, 'reference_image'>
|
||||
EntityIdentifierPayload<
|
||||
{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null },
|
||||
'reference_image'
|
||||
>
|
||||
>
|
||||
) => {
|
||||
const { entityIdentifier, modelConfig } = action.payload;
|
||||
@@ -652,14 +657,36 @@ export const canvasSlice = createSlice({
|
||||
if (!entity) {
|
||||
return;
|
||||
}
|
||||
|
||||
const oldModel = entity.ipAdapter.model;
|
||||
|
||||
// First set the new model
|
||||
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
|
||||
|
||||
if (!entity.ipAdapter.model) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (entity.ipAdapter.type === 'ip_adapter' && entity.ipAdapter.model.type === 'flux_redux') {
|
||||
// Switching from ip_adapter to flux_redux
|
||||
if (isEqual(oldModel, entity.ipAdapter.model)) {
|
||||
// Nothing changed, so we don't need to do anything
|
||||
return;
|
||||
}
|
||||
|
||||
// The type of ref image depends on the model. When the user switches the model, we rebuild the ref image.
|
||||
// When we switch the model, we keep the image the same, but change the other parameters.
|
||||
|
||||
if (entity.ipAdapter.model.base === 'chatgpt-4o') {
|
||||
// Switching to chatgpt-4o ref image
|
||||
entity.ipAdapter = {
|
||||
...initialChatGPT4oReferenceImage,
|
||||
image: entity.ipAdapter.image,
|
||||
model: entity.ipAdapter.model,
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
if (entity.ipAdapter.model.type === 'flux_redux') {
|
||||
// Switching to flux_redux
|
||||
entity.ipAdapter = {
|
||||
...initialFLUXRedux,
|
||||
image: entity.ipAdapter.image,
|
||||
@@ -668,17 +695,13 @@ export const canvasSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
if (entity.ipAdapter.type === 'flux_redux' && entity.ipAdapter.model.type === 'ip_adapter') {
|
||||
// Switching from flux_redux to ip_adapter
|
||||
if (entity.ipAdapter.model.type === 'ip_adapter') {
|
||||
// Switching to ip_adapter
|
||||
entity.ipAdapter = {
|
||||
...initialIPAdapter,
|
||||
image: entity.ipAdapter.image,
|
||||
model: entity.ipAdapter.model,
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
if (entity.ipAdapter.type === 'ip_adapter') {
|
||||
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
|
||||
if (entity.ipAdapter.model?.base === 'flux') {
|
||||
entity.ipAdapter.clipVisionModel = 'ViT-L';
|
||||
@@ -686,6 +709,7 @@ export const canvasSlice = createSlice({
|
||||
// Fall back to ViT-H (ViT-G would also work)
|
||||
entity.ipAdapter.clipVisionModel = 'ViT-H';
|
||||
}
|
||||
return;
|
||||
}
|
||||
},
|
||||
referenceImageIPAdapterCLIPVisionModelChanged: (
|
||||
@@ -1139,7 +1163,21 @@ export const canvasSlice = createSlice({
|
||||
syncScaledSize(state);
|
||||
},
|
||||
bboxChangedFromCanvas: (state, action: PayloadAction<IRect>) => {
|
||||
state.bbox.rect = action.payload;
|
||||
const newBboxRect = action.payload;
|
||||
const oldBboxRect = state.bbox.rect;
|
||||
|
||||
state.bbox.rect = newBboxRect;
|
||||
|
||||
if (newBboxRect.width === oldBboxRect.width && newBboxRect.height === oldBboxRect.height) {
|
||||
return;
|
||||
}
|
||||
|
||||
const oldAspectRatio = state.bbox.aspectRatio.value;
|
||||
const newAspectRatio = newBboxRect.width / newBboxRect.height;
|
||||
|
||||
if (oldAspectRatio === newAspectRatio) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(psyche): Figure out a way to handle this without resetting the aspect ratio on every change.
|
||||
// This action is dispatched when the user resizes or moves the bbox from the canvas. For now, when the user
|
||||
@@ -1198,6 +1236,40 @@ export const canvasSlice = createSlice({
|
||||
state.bbox.aspectRatio.id = id;
|
||||
if (id === 'Free') {
|
||||
state.bbox.aspectRatio.isLocked = false;
|
||||
} else if (state.bbox.modelBase === 'imagen3' && isImagen3AspectRatioID(id)) {
|
||||
// Imagen3 has specific output sizes that are not exactly the same as the aspect ratio. Need special handling.
|
||||
if (id === '16:9') {
|
||||
state.bbox.rect.width = 1408;
|
||||
state.bbox.rect.height = 768;
|
||||
} else if (id === '4:3') {
|
||||
state.bbox.rect.width = 1280;
|
||||
state.bbox.rect.height = 896;
|
||||
} else if (id === '1:1') {
|
||||
state.bbox.rect.width = 1024;
|
||||
state.bbox.rect.height = 1024;
|
||||
} else if (id === '3:4') {
|
||||
state.bbox.rect.width = 896;
|
||||
state.bbox.rect.height = 1280;
|
||||
} else if (id === '9:16') {
|
||||
state.bbox.rect.width = 768;
|
||||
state.bbox.rect.height = 1408;
|
||||
}
|
||||
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
|
||||
state.bbox.aspectRatio.isLocked = true;
|
||||
} else if (state.bbox.modelBase === 'chatgpt-4o' && isChatGPT4oAspectRatioID(id)) {
|
||||
// gpt-image has specific output sizes that are not exactly the same as the aspect ratio. Need special handling.
|
||||
if (id === '3:2') {
|
||||
state.bbox.rect.width = 1536;
|
||||
state.bbox.rect.height = 1024;
|
||||
} else if (id === '1:1') {
|
||||
state.bbox.rect.width = 1024;
|
||||
state.bbox.rect.height = 1024;
|
||||
} else if (id === '2:3') {
|
||||
state.bbox.rect.width = 1024;
|
||||
state.bbox.rect.height = 1536;
|
||||
}
|
||||
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
|
||||
state.bbox.aspectRatio.isLocked = true;
|
||||
} else {
|
||||
state.bbox.aspectRatio.isLocked = true;
|
||||
state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
|
||||
@@ -1670,6 +1742,13 @@ export const canvasSlice = createSlice({
|
||||
const base = model?.base;
|
||||
if (isMainModelBase(base) && state.bbox.modelBase !== base) {
|
||||
state.bbox.modelBase = base;
|
||||
if (base === 'imagen3' || base === 'chatgpt-4o') {
|
||||
state.bbox.aspectRatio.isLocked = true;
|
||||
state.bbox.aspectRatio.value = 1;
|
||||
state.bbox.aspectRatio.id = '1:1';
|
||||
state.bbox.rect.width = 1024;
|
||||
state.bbox.rect.height = 1024;
|
||||
}
|
||||
syncScaledSize(state);
|
||||
}
|
||||
});
|
||||
@@ -1802,6 +1881,10 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
|
||||
};
|
||||
|
||||
const syncScaledSize = (state: CanvasState) => {
|
||||
if (state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'chatgpt-4o') {
|
||||
// Imagen3 has fixed sizes. Scaled bbox is not supported.
|
||||
return;
|
||||
}
|
||||
if (state.bbox.scaleMethod === 'auto') {
|
||||
// Sync both aspect ratio and size
|
||||
const { width, height } = state.bbox.rect;
|
||||
|
||||
@@ -380,6 +380,8 @@ export const selectIsSDXL = createParamsSelector((params) => params.model?.base
|
||||
export const selectIsFLUX = createParamsSelector((params) => params.model?.base === 'flux');
|
||||
export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3');
|
||||
export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4');
|
||||
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
|
||||
export const selectIsChatGTP4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
|
||||
|
||||
export const selectModel = createParamsSelector((params) => params.model);
|
||||
export const selectModelKey = createParamsSelector((params) => params.model?.key);
|
||||
|
||||
@@ -245,6 +245,18 @@ const zFLUXReduxConfig = z.object({
|
||||
});
|
||||
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
|
||||
|
||||
const zChatGPT4oReferenceImageConfig = z.object({
|
||||
type: z.literal('chatgpt_4o_reference_image'),
|
||||
image: zImageWithDims.nullable(),
|
||||
/**
|
||||
* TODO(psyche): Technically there is no model for ChatGPT 4o reference images - it's just a field in the API call.
|
||||
* But we use a model drop down to switch between different ref image types, so there needs to be a model here else
|
||||
* there will be no way to switch between ref image types.
|
||||
*/
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
});
|
||||
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
|
||||
|
||||
const zCanvasEntityBase = z.object({
|
||||
id: zId,
|
||||
name: zName,
|
||||
@@ -254,15 +266,19 @@ const zCanvasEntityBase = z.object({
|
||||
|
||||
const zCanvasReferenceImageState = zCanvasEntityBase.extend({
|
||||
type: z.literal('reference_image'),
|
||||
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
|
||||
// This should be named `referenceImage` but we need to keep it as `ipAdapter` for backwards compatibility
|
||||
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig, zChatGPT4oReferenceImageConfig]),
|
||||
});
|
||||
export type CanvasReferenceImageState = z.infer<typeof zCanvasReferenceImageState>;
|
||||
|
||||
export const isIPAdapterConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is IPAdapterConfig =>
|
||||
export const isIPAdapterConfig = (config: CanvasReferenceImageState['ipAdapter']): config is IPAdapterConfig =>
|
||||
config.type === 'ip_adapter';
|
||||
|
||||
export const isFLUXReduxConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is FLUXReduxConfig =>
|
||||
export const isFLUXReduxConfig = (config: CanvasReferenceImageState['ipAdapter']): config is FLUXReduxConfig =>
|
||||
config.type === 'flux_redux';
|
||||
export const isChatGPT4oReferenceImageConfig = (
|
||||
config: CanvasReferenceImageState['ipAdapter']
|
||||
): config is ChatGPT4oReferenceImageConfig => config.type === 'chatgpt_4o_reference_image';
|
||||
|
||||
const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
|
||||
export type FillStyle = z.infer<typeof zFillStyle>;
|
||||
@@ -387,9 +403,18 @@ export type StagingAreaImage = {
|
||||
offsetY: number;
|
||||
};
|
||||
|
||||
const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
|
||||
export const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
|
||||
|
||||
export const zImagen3AspectRatioID = z.enum(['16:9', '4:3', '1:1', '3:4', '9:16']);
|
||||
export const isImagen3AspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
|
||||
zImagen3AspectRatioID.safeParse(v).success;
|
||||
|
||||
export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);
|
||||
export const isChatGPT4oAspectRatioID = (v: unknown): v is z.infer<typeof zChatGPT4oAspectRatioID> =>
|
||||
zChatGPT4oAspectRatioID.safeParse(v).success;
|
||||
|
||||
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
|
||||
export const isAspectRatioID = (v: string): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
|
||||
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
|
||||
|
||||
const zCanvasState = z.object({
|
||||
_version: z.literal(3),
|
||||
|
||||
@@ -7,6 +7,7 @@ import type {
|
||||
CanvasRasterLayerState,
|
||||
CanvasReferenceImageState,
|
||||
CanvasRegionalGuidanceState,
|
||||
ChatGPT4oReferenceImageConfig,
|
||||
ControlLoRAConfig,
|
||||
ControlNetConfig,
|
||||
FLUXReduxConfig,
|
||||
@@ -77,6 +78,11 @@ export const initialFLUXRedux: FLUXReduxConfig = {
|
||||
model: null,
|
||||
imageInfluence: 'highest',
|
||||
};
|
||||
export const initialChatGPT4oReferenceImage: ChatGPT4oReferenceImageConfig = {
|
||||
type: 'chatgpt_4o_reference_image',
|
||||
image: null,
|
||||
model: null,
|
||||
};
|
||||
export const initialT2IAdapter: T2IAdapterConfig = {
|
||||
type: 't2i_adapter',
|
||||
model: null,
|
||||
|
||||
@@ -4,7 +4,7 @@ import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
|
||||
type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
|
||||
export type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
|
||||
export const isSeedBehaviour = (v: unknown): v is SeedBehaviour => zSeedBehaviour.safeParse(v).success;
|
||||
|
||||
export interface DynamicPromptsState {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { selectDefaultIPAdapter, selectDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { canvasReset } from 'features/controlLayers/store/actions';
|
||||
@@ -116,7 +116,7 @@ export const createNewCanvasEntityFromImage = (arg: {
|
||||
break;
|
||||
}
|
||||
case 'reference_image': {
|
||||
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
|
||||
const ipAdapter = deepClone(selectDefaultRefImageConfig(getState()));
|
||||
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
|
||||
dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));
|
||||
break;
|
||||
@@ -238,7 +238,7 @@ export const newCanvasFromImage = (arg: {
|
||||
break;
|
||||
}
|
||||
case 'reference_image': {
|
||||
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
|
||||
const ipAdapter = deepClone(selectDefaultRefImageConfig(getState()));
|
||||
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
|
||||
dispatch(canvasReset());
|
||||
dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));
|
||||
|
||||
@@ -58,7 +58,7 @@ const LoRASelect = () => {
|
||||
const noOptionsMessage = useCallback(() => t('models.noMatchingLoRAs'), [t]);
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={!options.length}>
|
||||
<FormControl isDisabled={!options.length} gap={2}>
|
||||
<InformationalPopover feature="lora">
|
||||
<FormLabel>{t('models.concepts')} </FormLabel>
|
||||
</InformationalPopover>
|
||||
|
||||
@@ -12,63 +12,45 @@ import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetHFTokenStatusQuery, useSetHFTokenMutation } from 'services/api/endpoints/models';
|
||||
import { UNAUTHORIZED_TOAST_ID } from 'services/events/onModelInstallError';
|
||||
import {
|
||||
useGetHFTokenStatusQuery,
|
||||
useResetHFTokenMutation,
|
||||
useSetHFTokenMutation,
|
||||
} from 'services/api/endpoints/models';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const HFToken = () => {
|
||||
const { t } = useTranslation();
|
||||
const isHFTokenEnabled = useFeatureStatus('hfToken');
|
||||
const [token, setToken] = useState('');
|
||||
const { currentData } = useGetHFTokenStatusQuery(isHFTokenEnabled ? undefined : skipToken);
|
||||
const [trigger, { isLoading, isUninitialized }] = useSetHFTokenMutation();
|
||||
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setToken(e.target.value);
|
||||
}, []);
|
||||
const onClick = useCallback(() => {
|
||||
trigger({ token })
|
||||
.unwrap()
|
||||
.then((res) => {
|
||||
if (res === 'valid') {
|
||||
setToken('');
|
||||
toast({
|
||||
id: UNAUTHORIZED_TOAST_ID,
|
||||
title: t('modelManager.hfTokenSaved'),
|
||||
status: 'success',
|
||||
duration: 3000,
|
||||
});
|
||||
}
|
||||
});
|
||||
}, [t, token, trigger]);
|
||||
|
||||
const error = useMemo(() => {
|
||||
if (!currentData || isUninitialized || isLoading) {
|
||||
return null;
|
||||
switch (currentData) {
|
||||
case 'invalid':
|
||||
return t('modelManager.hfTokenInvalidErrorMessage');
|
||||
case 'unknown':
|
||||
return t('modelManager.hfTokenUnableToVerifyErrorMessage');
|
||||
case 'valid':
|
||||
case undefined:
|
||||
return null;
|
||||
default:
|
||||
assert<Equals<never, typeof currentData>>(false, 'Unexpected HF token status');
|
||||
}
|
||||
if (currentData === 'invalid') {
|
||||
return t('modelManager.hfTokenInvalidErrorMessage');
|
||||
}
|
||||
if (currentData === 'unknown') {
|
||||
return t('modelManager.hfTokenUnableToVerifyErrorMessage');
|
||||
}
|
||||
return null;
|
||||
}, [currentData, isLoading, isUninitialized, t]);
|
||||
}, [currentData, t]);
|
||||
|
||||
if (!currentData || currentData === 'valid') {
|
||||
if (!currentData) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex borderRadius="base" w="full">
|
||||
<FormControl isInvalid={!isUninitialized && Boolean(error)} orientation="vertical">
|
||||
<FormControl isInvalid={Boolean(error)} orientation="vertical">
|
||||
<FormLabel>{t('modelManager.hfTokenLabel')}</FormLabel>
|
||||
<Flex gap={3} alignItems="center" w="full">
|
||||
<Input type="password" value={token} onChange={onChange} />
|
||||
<Button onClick={onClick} size="sm" isDisabled={token.trim().length === 0} isLoading={isLoading}>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
{error && <SetHFTokenInput />}
|
||||
{!error && <ResetHFTokenButton />}
|
||||
<FormHelperText>
|
||||
<ExternalLink label={t('modelManager.hfTokenHelperText')} href="https://huggingface.co/settings/tokens" />
|
||||
</FormHelperText>
|
||||
@@ -77,3 +59,73 @@ export const HFToken = () => {
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
const PLACEHOLDER_TOKEN = Array.from({ length: 37 }, () => 'a').join('');
|
||||
|
||||
const ResetHFTokenButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const [resetHFToken, { isLoading }] = useResetHFTokenMutation();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
resetHFToken()
|
||||
.unwrap()
|
||||
.then(() => {
|
||||
toast({
|
||||
title: t('modelManager.hfTokenReset'),
|
||||
status: 'info',
|
||||
});
|
||||
});
|
||||
}, [resetHFToken, t]);
|
||||
|
||||
return (
|
||||
<Flex gap={3} alignItems="center" w="full">
|
||||
<Input type="password" value={PLACEHOLDER_TOKEN} isDisabled />
|
||||
<Button onClick={onClick} size="sm" isLoading={isLoading}>
|
||||
{t('common.reset')}
|
||||
</Button>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
ResetHFTokenButton.displayName = 'ResetHFTokenButton';
|
||||
|
||||
const SetHFTokenInput = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const [token, setToken] = useState('');
|
||||
const [trigger, { isLoading }] = useSetHFTokenMutation();
|
||||
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setToken(e.target.value);
|
||||
}, []);
|
||||
const onClick = useCallback(() => {
|
||||
trigger({ token })
|
||||
.unwrap()
|
||||
.then((res) => {
|
||||
switch (res) {
|
||||
case 'valid':
|
||||
setToken('');
|
||||
toast({
|
||||
title: t('modelManager.hfTokenSaved'),
|
||||
status: 'success',
|
||||
});
|
||||
break;
|
||||
case 'invalid':
|
||||
case 'unknown':
|
||||
default:
|
||||
toast({
|
||||
title: t('modelManager.hfTokenUnableToVerify'),
|
||||
status: 'error',
|
||||
});
|
||||
break;
|
||||
}
|
||||
});
|
||||
}, [t, token, trigger]);
|
||||
|
||||
return (
|
||||
<Flex gap={3} alignItems="center" w="full">
|
||||
<Input type="password" value={token} onChange={onChange} />
|
||||
<Button onClick={onClick} size="sm" isDisabled={token.trim().length === 0} isLoading={isLoading}>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
SetHFTokenInput.displayName = 'SetHFTokenInput';
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetHFTokenStatusQuery, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { HFToken } from './HFToken';
|
||||
import { HuggingFaceResults } from './HuggingFaceResults';
|
||||
@@ -16,7 +15,6 @@ export const HuggingFaceForm = memo(() => {
|
||||
const [errorMessage, setErrorMessage] = useState('');
|
||||
const { t } = useTranslation();
|
||||
const isHFTokenEnabled = useFeatureStatus('hfToken');
|
||||
const { currentData } = useGetHFTokenStatusQuery(isHFTokenEnabled ? undefined : skipToken);
|
||||
|
||||
const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
|
||||
const [installModel] = useInstallModel();
|
||||
@@ -68,7 +66,7 @@ export const HuggingFaceForm = memo(() => {
|
||||
<FormHelperText>{t('modelManager.huggingFaceHelper')}</FormHelperText>
|
||||
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
{currentData !== 'valid' && <HFToken />}
|
||||
{isHFTokenEnabled && <HFToken />}
|
||||
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -7,7 +7,7 @@ type Props = {
|
||||
base: BaseModelType;
|
||||
};
|
||||
|
||||
const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
||||
export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
||||
any: 'base',
|
||||
'sd-1': 'green',
|
||||
'sd-2': 'teal',
|
||||
@@ -15,12 +15,14 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
||||
sdxl: 'invokeBlue',
|
||||
'sdxl-refiner': 'invokeBlue',
|
||||
flux: 'gold',
|
||||
cogview4: 'orange',
|
||||
cogview4: 'red',
|
||||
imagen3: 'pink',
|
||||
'chatgpt-4o': 'pink',
|
||||
};
|
||||
|
||||
const ModelBaseBadge = ({ base }: Props) => {
|
||||
return (
|
||||
<Badge flexGrow={0} colorScheme={BASE_COLOR_MAP[base]} variant="subtle" h="min-content">
|
||||
<Badge flexGrow={0} flexShrink={0} colorScheme={BASE_COLOR_MAP[base]} variant="subtle" h="min-content">
|
||||
{MODEL_TYPE_SHORT_MAP[base]}
|
||||
</Badge>
|
||||
);
|
||||
|
||||
@@ -17,6 +17,7 @@ const FORMAT_NAME_MAP: Record<AnyModelConfig['format'], string> = {
|
||||
bnb_quantized_int8b: 'bnb_quantized_int8b',
|
||||
bnb_quantized_nf4b: 'quantized',
|
||||
gguf_quantized: 'gguf',
|
||||
api: 'api',
|
||||
};
|
||||
|
||||
const FORMAT_COLOR_MAP: Record<AnyModelConfig['format'], string> = {
|
||||
@@ -30,6 +31,7 @@ const FORMAT_COLOR_MAP: Record<AnyModelConfig['format'], string> = {
|
||||
bnb_quantized_int8b: 'base',
|
||||
bnb_quantized_nf4b: 'base',
|
||||
gguf_quantized: 'base',
|
||||
api: 'base',
|
||||
};
|
||||
|
||||
const ModelFormatBadge = ({ format }: Props) => {
|
||||
|
||||
@@ -15,7 +15,6 @@ const ModelImage = ({ image_url }: Props) => {
|
||||
<Flex
|
||||
height={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
|
||||
bg="base.650"
|
||||
borderRadius="base"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import { FloatFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInput';
|
||||
import { FloatFieldInputAndSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInputAndSlider';
|
||||
import { FloatFieldSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldSlider';
|
||||
import ChatGPT4oModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ChatGPT4oModelFieldInputComponent';
|
||||
import { FloatFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatFieldCollectionInputComponent';
|
||||
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
|
||||
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
|
||||
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
|
||||
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
|
||||
import { IntegerFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerFieldCollectionInputComponent';
|
||||
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
|
||||
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
|
||||
@@ -23,6 +25,8 @@ import {
|
||||
isBoardFieldInputTemplate,
|
||||
isBooleanFieldInputInstance,
|
||||
isBooleanFieldInputTemplate,
|
||||
isChatGPT4oModelFieldInputInstance,
|
||||
isChatGPT4oModelFieldInputTemplate,
|
||||
isCLIPEmbedModelFieldInputInstance,
|
||||
isCLIPEmbedModelFieldInputTemplate,
|
||||
isCLIPGEmbedModelFieldInputInstance,
|
||||
@@ -57,6 +61,8 @@ import {
|
||||
isImageFieldInputTemplate,
|
||||
isImageGeneratorFieldInputInstance,
|
||||
isImageGeneratorFieldInputTemplate,
|
||||
isImagen3ModelFieldInputInstance,
|
||||
isImagen3ModelFieldInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
@@ -394,6 +400,20 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <FluxReduxModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isImagen3ModelFieldInputTemplate(template)) {
|
||||
if (!isImagen3ModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <Imagen3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isChatGPT4oModelFieldInputTemplate(template)) {
|
||||
if (!isChatGPT4oModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <ChatGPT4oModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isColorFieldInputTemplate(template)) {
|
||||
if (!isColorFieldInputInstance(field)) {
|
||||
return null;
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { CLIPEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
@@ -15,11 +12,9 @@ type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedMode
|
||||
|
||||
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: CLIPEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -34,32 +29,15 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
const required = props.fieldTemplate.required;
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl
|
||||
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
|
||||
isDisabled={!options.length}
|
||||
isInvalid={!value && required}
|
||||
>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,11 +1,8 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPGEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPGEmbedModelConfig, isCLIPGEmbedModelConfig } from 'services/api/types';

@@ -15,12 +12,10 @@ type Props = FieldComponentProps<CLIPGEmbedModelFieldInputInstance, CLIPGEmbedMo

const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
  const { nodeId, field } = props;
  const { t } = useTranslation();
  const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
  const dispatch = useAppDispatch();
  const [modelConfigs, { isLoading }] = useCLIPEmbedModels();

  const _onChange = useCallback(
  const onChange = useCallback(
    (value: CLIPGEmbedModelConfig | null) => {
      if (!value) {
        return;
@@ -35,32 +30,15 @@ const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
    },
    [dispatch, field.name, nodeId]
  );
  const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
    modelConfigs: modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config)),
    onChange: _onChange,
    isLoading,
    selectedModel: field.value,
  });
  const required = props.fieldTemplate.required;

  return (
    <Flex w="full" alignItems="center" gap={2}>
      <Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
        <FormControl
          className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
          isDisabled={!options.length}
          isInvalid={!value && required}
        >
          <Combobox
            value={value}
            placeholder={required ? placeholder : `(Optional) ${placeholder}`}
            options={options}
            onChange={onChange}
            noOptionsMessage={noOptionsMessage}
          />
        </FormControl>
      </Tooltip>
    </Flex>
    <ModelFieldCombobox
      value={field.value}
      modelConfigs={modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config))}
      isLoadingConfigs={isLoading}
      onChange={onChange}
      required={props.fieldTemplate.required}
    />
  );
};

@@ -1,11 +1,8 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPLEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPLEmbedModelConfig, isCLIPLEmbedModelConfig } from 'services/api/types';

@@ -15,12 +12,10 @@ type Props = FieldComponentProps<CLIPLEmbedModelFieldInputInstance, CLIPLEmbedMo

const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
  const { nodeId, field } = props;
  const { t } = useTranslation();
  const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
  const dispatch = useAppDispatch();
  const [modelConfigs, { isLoading }] = useCLIPEmbedModels();

  const _onChange = useCallback(
  const onChange = useCallback(
    (value: CLIPLEmbedModelConfig | null) => {
      if (!value) {
        return;
@@ -35,32 +30,15 @@ const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
    },
    [dispatch, field.name, nodeId]
  );
  const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
    modelConfigs: modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config)),
    onChange: _onChange,
    isLoading,
    selectedModel: field.value,
  });
  const required = props.fieldTemplate.required;

  return (
    <Flex w="full" alignItems="center" gap={2}>
      <Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
        <FormControl
          className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
          isDisabled={!options.length}
          isInvalid={!value && required}
        >
          <Combobox
            value={value}
            placeholder={required ? placeholder : `(Optional) ${placeholder}`}
            options={options}
            onChange={onChange}
            noOptionsMessage={noOptionsMessage}
          />
        </FormControl>
      </Tooltip>
    </Flex>
    <ModelFieldCombobox
      value={field.value}
      modelConfigs={modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config))}
      isLoadingConfigs={isLoading}
      onChange={onChange}
      required={props.fieldTemplate.required}
    />
  );
};

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldChatGPT4oModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useChatGPT4oModels } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

const ChatGPT4oModelFieldInputComponent = (
  props: FieldComponentProps<ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate>
) => {
  const { nodeId, field } = props;
  const dispatch = useAppDispatch();

  const [modelConfigs, { isLoading }] = useChatGPT4oModels();

  const onChange = useCallback(
    (value: ApiModelConfig | null) => {
      if (!value) {
        return;
      }
      dispatch(
        fieldChatGPT4oModelValueChanged({
          nodeId,
          fieldName: field.name,
          value,
        })
      );
    },
    [dispatch, field.name, nodeId]
  );

  return (
    <ModelFieldCombobox
      value={field.value}
      modelConfigs={modelConfigs}
      isLoadingConfigs={isLoading}
      onChange={onChange}
      required={props.fieldTemplate.required}
    />
  );
};

export default memo(ChatGPT4oModelFieldInputComponent);

@@ -1,16 +1,13 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldControlLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type {
  ControlLoRAModelFieldInputInstance,
  ControlLoRAModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlLoRAModel } from 'services/api/hooks/modelsByType';
import { type ControlLoRAModelConfig, isControlLoRAModelConfig } from 'services/api/types';
import type { ControlLoRAModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

@@ -18,12 +15,10 @@ type Props = FieldComponentProps<ControlLoRAModelFieldInputInstance, ControlLoRA

const ControlLoRAModelFieldInputComponent = (props: Props) => {
  const { nodeId, field } = props;
  const { t } = useTranslation();
  const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
  const dispatch = useAppDispatch();
  const [modelConfigs, { isLoading }] = useControlLoRAModel();

  const _onChange = useCallback(
  const onChange = useCallback(
    (value: ControlLoRAModelConfig | null) => {
      if (!value) {
        return;
@@ -38,32 +33,15 @@ const ControlLoRAModelFieldInputComponent = (props: Props) => {
    },
    [dispatch, field.name, nodeId]
  );
  const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
    modelConfigs: modelConfigs.filter((config) => isControlLoRAModelConfig(config)),
    onChange: _onChange,
    isLoading,
    selectedModel: field.value,
  });
  const required = props.fieldTemplate.required;

  return (
    <Flex w="full" alignItems="center" gap={2}>
      <Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
        <FormControl
          className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
          isDisabled={!options.length}
          isInvalid={!value && required}
        >
          <Combobox
            value={value}
            placeholder={required ? placeholder : `(Optional) ${placeholder}`}
            options={options}
            onChange={onChange}
            noOptionsMessage={noOptionsMessage}
          />
        </FormControl>
      </Tooltip>
    </Flex>
    <ModelFieldCombobox
      value={field.value}
      modelConfigs={modelConfigs}
      isLoadingConfigs={isLoading}
      onChange={onChange}
      required={props.fieldTemplate.required}
    />
  );
};

@@ -1,8 +1,6 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useControlNetModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -17,7 +15,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useControlNetModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: ControlNetModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -33,25 +31,14 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -16,7 +14,7 @@ const FluxMainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useFluxModels();
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -31,25 +29,15 @@ const FluxMainModelFieldInputComponent = (props: Props) => {
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldFluxReduxModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxReduxModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -18,7 +16,7 @@ const FluxReduxModelFieldInputComponent = (
|
||||
|
||||
const [modelConfigs, { isLoading }] = useFluxReduxModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: FLUXReduxModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -34,19 +32,14 @@ const FluxReduxModelFieldInputComponent = (
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
|
||||
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
@@ -15,11 +12,9 @@ type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFie
|
||||
|
||||
const FluxVAEModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useFluxVAEModels();
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: VAEModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -34,27 +29,15 @@ const FluxVAEModelFieldInputComponent = (props: Props) => {
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -17,7 +15,7 @@ const IPAdapterModelFieldInputComponent = (
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useIPAdapterModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: IPAdapterModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -33,19 +31,14 @@ const IPAdapterModelFieldInputComponent = (
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
|
||||
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldImagen3ModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useImagen3Models } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

const Imagen3ModelFieldInputComponent = (
  props: FieldComponentProps<Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate>
) => {
  const { nodeId, field } = props;
  const dispatch = useAppDispatch();

  const [modelConfigs, { isLoading }] = useImagen3Models();

  const onChange = useCallback(
    (value: ApiModelConfig | null) => {
      if (!value) {
        return;
      }
      dispatch(
        fieldImagen3ModelValueChanged({
          nodeId,
          fieldName: field.name,
          value,
        })
      );
    },
    [dispatch, field.name, nodeId]
  );

  return (
    <ModelFieldCombobox
      value={field.value}
      modelConfigs={modelConfigs}
      isLoadingConfigs={isLoading}
      onChange={onChange}
      required={props.fieldTemplate.required}
    />
  );
};

export default memo(Imagen3ModelFieldInputComponent);

@@ -1,8 +1,6 @@
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldLLaVAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useLLaVAModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -16,7 +14,7 @@ const LLaVAModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useLLaVAModels();
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: LlavaOnevisionConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -32,23 +30,14 @@ const LLaVAModelFieldInputComponent = (props: Props) => {
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value} isDisabled={!options.length}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
/>
|
||||
</FormControl>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -16,7 +14,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useLoRAModels();
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: LoRAModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -32,23 +30,14 @@ const LoRAModelFieldInputComponent = (props: Props) => {
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value} isDisabled={!options.length}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
/>
|
||||
</FormControl>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -16,7 +14,7 @@ const MainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -31,25 +29,15 @@ const MainModelFieldInputComponent = (props: Props) => {
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,51 @@
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { typedMemo } from 'common/util/typedMemo';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { AnyModelConfig } from 'services/api/types';

type Props<T extends AnyModelConfig> = {
  value: ModelIdentifierField | undefined;
  modelConfigs: T[];
  isLoadingConfigs: boolean;
  onChange: (value: T | null) => void;
  required: boolean;
  groupByType?: boolean;
};

const _ModelFieldCombobox = <T extends AnyModelConfig>({
  value: _value,
  modelConfigs,
  isLoadingConfigs,
  onChange: _onChange,
  required,
  groupByType,
}: Props<T>) => {
  const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
    modelConfigs,
    onChange: _onChange,
    isLoading: isLoadingConfigs,
    selectedModel: _value,
    groupByType,
  });

  return (
    <FormControl
      className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
      isDisabled={!options.length}
      isInvalid={!value && required}
      gap={2}
    >
      <Combobox
        value={value}
        placeholder={required ? placeholder : `(Optional) ${placeholder}`}
        options={options}
        onChange={onChange}
        noOptionsMessage={noOptionsMessage}
      />
    </FormControl>
  );
};

export const ModelFieldCombobox = typedMemo(_ModelFieldCombobox);

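The `ModelFieldCombobox` introduced above is the shared wrapper that the per-model field inputs now delegate to instead of each calling `useGroupedModelCombobox` directly. Below is a minimal sketch of how a field input component can wire it up, mirroring the new Imagen3/ChatGPT-4o components; `useExampleModels` and `fieldExampleModelValueChanged` are placeholder names for illustration, not real exports:

```tsx
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { useCallback } from 'react';
import type { AnyModelConfig } from 'services/api/types';

// Placeholders for illustration only – substitute the real modelsByType hook
// and nodesSlice action for the field being rendered.
declare const useExampleModels: () => [AnyModelConfig[], { isLoading: boolean }];
declare const fieldExampleModelValueChanged: (arg: {
  nodeId: string;
  fieldName: string;
  value: AnyModelConfig;
}) => { type: string; payload: unknown };

const ExampleModelFieldInput = (props: {
  nodeId: string;
  field: { name: string; value: ModelIdentifierField | undefined };
  fieldTemplate: { required: boolean };
}) => {
  const { nodeId, field, fieldTemplate } = props;
  const dispatch = useAppDispatch();
  const [modelConfigs, { isLoading }] = useExampleModels();

  // Persist the selected model config to the node's field, ignoring cleared values.
  const onChange = useCallback(
    (value: AnyModelConfig | null) => {
      if (!value) {
        return;
      }
      dispatch(fieldExampleModelValueChanged({ nodeId, fieldName: field.name, value }));
    },
    [dispatch, field.name, nodeId]
  );

  return (
    <ModelFieldCombobox
      value={field.value}
      modelConfigs={modelConfigs}
      isLoadingConfigs={isLoading}
      onChange={onChange}
      required={fieldTemplate.required}
    />
  );
};
```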
@@ -1,9 +1,7 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
@@ -17,7 +15,7 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetModelConfigsQuery();
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: AnyModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -41,26 +39,15 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
|
||||
return modelConfigsAdapterSelectors.selectAll(data);
|
||||
}, [data]);
|
||||
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
groupByType: true,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
groupByType
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
SDXLRefinerModelFieldInputInstance,
|
||||
SDXLRefinerModelFieldInputTemplate,
|
||||
@@ -19,7 +17,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useRefinerModels();
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -34,25 +32,15 @@ const RefinerModelFieldInputComponent = (props: Props) => {
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSD3Models } from 'services/api/hooks/modelsByType';
|
||||
@@ -16,7 +14,7 @@ const SD3MainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useSD3Models();
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -31,29 +29,15 @@ const SD3MainModelFieldInputComponent = (props: Props) => {
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl
|
||||
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
|
||||
isDisabled={!options.length}
|
||||
isInvalid={!value && props.fieldTemplate.required}
|
||||
>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSDXLModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -16,7 +14,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useSDXLModels();
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -31,25 +29,15 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldSigLipModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSigLipModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -18,7 +16,7 @@ const SigLipModelFieldInputComponent = (
|
||||
|
||||
const [modelConfigs, { isLoading }] = useSigLipModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: SigLipModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -34,19 +32,14 @@ const SigLipModelFieldInputComponent = (
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
|
||||
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldSpandrelImageToImageModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
SpandrelImageToImageModelFieldInputInstance,
|
||||
SpandrelImageToImageModelFieldInputTemplate,
|
||||
@@ -21,7 +19,7 @@ const SpandrelImageToImageModelFieldInputComponent = (
|
||||
|
||||
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: SpandrelImageToImageModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -37,19 +35,14 @@ const SpandrelImageToImageModelFieldInputComponent = (
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
|
||||
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -18,7 +16,7 @@ const T2IAdapterModelFieldInputComponent = (
|
||||
|
||||
const [modelConfigs, { isLoading }] = useT2IAdapterModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: T2IAdapterModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -34,19 +32,14 @@ const T2IAdapterModelFieldInputComponent = (
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
|
||||
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldT5EncoderValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { selectIsModelsTabDisabled } from 'features/system/store/configSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
|
||||
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
|
||||
|
||||
@@ -16,11 +12,9 @@ type Props = FieldComponentProps<T5EncoderModelFieldInputInstance, T5EncoderMode
|
||||
|
||||
const T5EncoderModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const isModelsTabDisabled = useAppSelector(selectIsModelsTabDisabled);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useT5EncoderModels();
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -35,31 +29,14 @@ const T5EncoderModelFieldInputComponent = (props: Props) => {
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
const required = props.fieldTemplate.required;
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!isModelsTabDisabled && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl
|
||||
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
|
||||
isDisabled={!options.length}
|
||||
isInvalid={!value && required}
|
||||
>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -16,7 +14,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useVAEModels();
|
||||
const _onChange = useCallback(
|
||||
const onChange = useCallback(
|
||||
(value: VAEModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
@@ -31,30 +29,15 @@ const VAEModelFieldInputComponent = (props: Props) => {
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
const required = props.fieldTemplate.required;
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl
|
||||
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
|
||||
isDisabled={!options.length}
|
||||
isInvalid={!value && required}
|
||||
>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -121,6 +121,10 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
  'metadata_to_controlnets',
  'metadata_to_ip_adapters',
  'metadata_to_t2i_adapters',
  'google_imagen3_generate',
  'google_imagen3_edit',
  'chatgpt_create_image',
  'chatgpt_edit_image',
];

export const selectHasUnpublishableNodes = createSelector(selectNodes, (nodes) => {

@@ -23,6 +23,7 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
  BoardFieldValue,
  BooleanFieldValue,
  ChatGPT4oModelFieldValue,
  CLIPEmbedModelFieldValue,
  CLIPGEmbedModelFieldValue,
  CLIPLEmbedModelFieldValue,
@@ -38,6 +39,7 @@ import type {
  ImageFieldCollectionValue,
  ImageFieldValue,
  ImageGeneratorFieldValue,
  Imagen3ModelFieldValue,
  IntegerFieldCollectionValue,
  IntegerFieldValue,
  IntegerGeneratorFieldValue,
@@ -61,6 +63,7 @@ import type {
import {
  zBoardFieldValue,
  zBooleanFieldValue,
  zChatGPT4oModelFieldValue,
  zCLIPEmbedModelFieldValue,
  zCLIPGEmbedModelFieldValue,
  zCLIPLEmbedModelFieldValue,
@@ -76,6 +79,7 @@ import {
  zImageFieldCollectionValue,
  zImageFieldValue,
  zImageGeneratorFieldValue,
  zImagen3ModelFieldValue,
  zIntegerFieldCollectionValue,
  zIntegerFieldValue,
  zIntegerGeneratorFieldValue,
@@ -512,6 +516,12 @@ export const nodesSlice = createSlice({
    fieldFluxReduxModelValueChanged: (state, action: FieldValueAction<FluxReduxModelFieldValue>) => {
      fieldValueReducer(state, action, zFluxReduxModelFieldValue);
    },
    fieldImagen3ModelValueChanged: (state, action: FieldValueAction<Imagen3ModelFieldValue>) => {
      fieldValueReducer(state, action, zImagen3ModelFieldValue);
    },
    fieldChatGPT4oModelValueChanged: (state, action: FieldValueAction<ChatGPT4oModelFieldValue>) => {
      fieldValueReducer(state, action, zChatGPT4oModelFieldValue);
    },
    fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
      fieldValueReducer(state, action, zEnumFieldValue);
    },
@@ -679,6 +689,8 @@ export const {
  fieldFluxVAEModelValueChanged,
  fieldSigLipModelValueChanged,
  fieldFluxReduxModelValueChanged,
  fieldImagen3ModelValueChanged,
  fieldChatGPT4oModelValueChanged,
  fieldFloatGeneratorValueChanged,
  fieldIntegerGeneratorValueChanged,
  fieldStringGeneratorValueChanged,

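Each new reducer above funnels through the slice's existing `fieldValueReducer`, validating the incoming payload against its zod schema before storing it on the field. A simplified sketch of that validate-then-set pattern (illustrative only, not the actual implementation, which also resolves the target node and field from the action):

```ts
import type { z } from 'zod';

// Illustrative only: reject payloads the schema does not accept, otherwise
// store the parsed value on the field.
const sketchFieldValueReducer = <T extends z.ZodTypeAny>(
  field: { value?: z.infer<T> },
  value: unknown,
  schema: T
): void => {
  const result = schema.safeParse(value);
  if (!result.success) {
    return;
  }
  field.value = result.data;
};
```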
@@ -1,4 +1,5 @@
import type {
  BaseModelType,
  BoardField,
  Classification,
  ColorField,
@@ -9,9 +10,10 @@ import type {
  ModelIdentifierField,
  ProgressImage,
  SchedulerField,
  SubModelType,
  T2IAdapterField,
} from 'features/nodes/types/common';
import type { Invocation, S } from 'services/api/types';
import type { Invocation, ModelType, S } from 'services/api/types';
import type { Equals, Extends } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
@@ -34,6 +36,9 @@ describe('Common types', () => {

  // Model component types
  test('ModelIdentifier', () => assert<Equals<ModelIdentifierField, S['ModelIdentifierField']>>());
  test('ModelIdentifier', () => assert<Equals<BaseModelType, S['BaseModelType']>>());
  test('ModelIdentifier', () => assert<Equals<SubModelType, S['SubModelType']>>());
  test('ModelIdentifier', () => assert<Equals<ModelType, S['ModelType']>>());

  // Misc types
  test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());

@@ -8,6 +8,11 @@ export const zImageField = z.object({
  image_name: z.string().trim().min(1),
});
export type ImageField = z.infer<typeof zImageField>;
export const isImageField = (field: unknown): field is ImageField => zImageField.safeParse(field).success;
const zImageFieldCollection = z.array(zImageField);
type ImageFieldCollection = z.infer<typeof zImageFieldCollection>;
export const isImageFieldCollection = (field: unknown): field is ImageFieldCollection =>
  zImageFieldCollection.safeParse(field).success;

export const zBoardField = z.object({
  board_id: z.string().trim().min(1),
@@ -61,8 +66,20 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion

// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner', 'flux', 'cogview4']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4']);
const zBaseModel = z.enum([
  'any',
  'sd-1',
  'sd-2',
  'sd-3',
  'sdxl',
  'sdxl-refiner',
  'flux',
  'cogview4',
  'imagen3',
  'chatgpt-4o',
]);
export type BaseModelType = z.infer<typeof zBaseModel>;
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o']);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([
@@ -98,6 +115,7 @@ const zSubModelType = z.enum([
  'scheduler',
  'safety_checker',
]);
export type SubModelType = z.infer<typeof zSubModelType>;
export const zModelIdentifierField = z.object({
  key: z.string().min(1),
  hash: z.string().min(1),

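With the widened enums above, the two hosted-API bases parse as valid main-model bases while `sdxl-refiner` remains valid only for `zBaseModel`; a quick illustrative check against the schemas defined above:

```ts
// Illustrative checks against the schemas defined above.
zMainModelBase.safeParse('imagen3').success; // true – newly added base
zMainModelBase.safeParse('chatgpt-4o').success; // true – newly added base
zMainModelBase.safeParse('sdxl-refiner').success; // false – the refiner base is only in zBaseModel
isMainModelBase('cogview4'); // true
```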
@@ -248,6 +248,14 @@ const zFluxReduxModelFieldType = zFieldTypeBase.extend({
  name: z.literal('FluxReduxModelField'),
  originalType: zStatelessFieldType.optional(),
});
const zImagen3ModelFieldType = zFieldTypeBase.extend({
  name: z.literal('Imagen3ModelField'),
  originalType: zStatelessFieldType.optional(),
});
const zChatGPT4oModelFieldType = zFieldTypeBase.extend({
  name: z.literal('ChatGPT4oModelField'),
  originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({
  name: z.literal('SchedulerField'),
  originalType: zStatelessFieldType.optional(),
@@ -298,6 +306,8 @@ const zStatefulFieldType = z.union([
  zFluxVAEModelFieldType,
  zSigLipModelFieldType,
  zFluxReduxModelFieldType,
  zImagen3ModelFieldType,
  zChatGPT4oModelFieldType,
  zColorFieldType,
  zSchedulerFieldType,
  zFloatGeneratorFieldType,
@@ -336,6 +346,8 @@ const modelFieldTypeNames = [
  zFluxVAEModelFieldType.shape.name.value,
  zSigLipModelFieldType.shape.name.value,
  zFluxReduxModelFieldType.shape.name.value,
  zImagen3ModelFieldType.shape.name.value,
  zChatGPT4oModelFieldType.shape.name.value,
  // Stateless model fields
  'UNetField',
  'VAEField',
@@ -1177,6 +1189,42 @@ export const isFluxReduxModelFieldInputTemplate =
  buildTemplateTypeGuard<FluxReduxModelFieldInputTemplate>('FluxReduxModelField');
// #endregion

// #region Imagen3ModelField
export const zImagen3ModelFieldValue = zModelIdentifierField.optional();
const zImagen3ModelFieldInputInstance = zFieldInputInstanceBase.extend({
  value: zImagen3ModelFieldValue,
});
const zImagen3ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
  type: zImagen3ModelFieldType,
  originalType: zFieldType.optional(),
  default: zImagen3ModelFieldValue,
});
export type Imagen3ModelFieldValue = z.infer<typeof zImagen3ModelFieldValue>;
export type Imagen3ModelFieldInputInstance = z.infer<typeof zImagen3ModelFieldInputInstance>;
export type Imagen3ModelFieldInputTemplate = z.infer<typeof zImagen3ModelFieldInputTemplate>;
export const isImagen3ModelFieldInputInstance = buildInstanceTypeGuard(zImagen3ModelFieldInputInstance);
export const isImagen3ModelFieldInputTemplate =
  buildTemplateTypeGuard<Imagen3ModelFieldInputTemplate>('Imagen3ModelField');
// #endregion

// #region ChatGPT4oModelField
export const zChatGPT4oModelFieldValue = zModelIdentifierField.optional();
const zChatGPT4oModelFieldInputInstance = zFieldInputInstanceBase.extend({
  value: zChatGPT4oModelFieldValue,
});
const zChatGPT4oModelFieldInputTemplate = zFieldInputTemplateBase.extend({
  type: zChatGPT4oModelFieldType,
  originalType: zFieldType.optional(),
  default: zChatGPT4oModelFieldValue,
});
export type ChatGPT4oModelFieldValue = z.infer<typeof zChatGPT4oModelFieldValue>;
export type ChatGPT4oModelFieldInputInstance = z.infer<typeof zChatGPT4oModelFieldInputInstance>;
export type ChatGPT4oModelFieldInputTemplate = z.infer<typeof zChatGPT4oModelFieldInputTemplate>;
export const isChatGPT4oModelFieldInputInstance = buildInstanceTypeGuard(zChatGPT4oModelFieldInputInstance);
export const isChatGPT4oModelFieldInputTemplate =
  buildTemplateTypeGuard<ChatGPT4oModelFieldInputTemplate>('ChatGPT4oModelField');
// #endregion

// #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional();
const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -1808,6 +1856,8 @@ export const zStatefulFieldValue = z.union([
  zControlLoRAModelFieldValue,
  zSigLipModelFieldValue,
  zFluxReduxModelFieldValue,
  zImagen3ModelFieldValue,
  zChatGPT4oModelFieldValue,
  zColorFieldValue,
  zSchedulerFieldValue,
  zFloatGeneratorFieldValue,
@@ -1898,6 +1948,8 @@ const zStatefulFieldInputTemplate = z.union([
  zControlLoRAModelFieldInputTemplate,
  zSigLipModelFieldInputTemplate,
  zFluxReduxModelFieldInputTemplate,
  zImagen3ModelFieldInputTemplate,
  zChatGPT4oModelFieldInputTemplate,
  zColorFieldInputTemplate,
  zSchedulerFieldInputTemplate,
  zStatelessFieldInputTemplate,
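For illustration, here is a hedged sketch of how the new field value schemas behave; the module path is assumed (the file these hunks belong to is not named in this excerpt), and only behaviour visible in the hunks above is exercised.

```ts
// Sketch only: import path assumed, not confirmed by this diff.
import { zChatGPT4oModelFieldValue, zImagen3ModelFieldValue } from 'features/nodes/types/field';

// Both values are optional ModelIdentifierFields, so undefined means "no model selected".
console.log(zImagen3ModelFieldValue.safeParse(undefined).success); // true

// A populated value must satisfy zModelIdentifierField; empty key/hash fail the min(1) checks.
console.log(zChatGPT4oModelFieldValue.safeParse({ key: '', hash: '' }).success); // false
```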
@@ -1,38 +1,63 @@
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from 'app/constants';
import type { RootState } from 'app/store/store';
import { generateSeeds } from 'common/util/generateSeeds';
import randomInt from 'common/util/randomInt';
import type { SeedBehaviour } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { range } from 'lodash-es';
import type { components } from 'services/api/schema';
import type { Batch, EnqueueBatchArg, Invocation } from 'services/api/types';
import type { Batch, EnqueueBatchArg } from 'services/api/types';
import { assert } from 'tsafe';

import type { ConditioningNodes, NoiseNodes } from './types';
const getExtendedPrompts = (arg: {
  seedBehaviour: SeedBehaviour;
  iterations: number;
  prompts: string[];
  model: ModelIdentifierField;
}): string[] => {
  const { seedBehaviour, iterations, prompts, model } = arg;
  // Normally, the seed behaviour implicitly determines the batch size. But when we use models without seeds (like
  // ChatGPT 4o) in conjunction with the per-prompt seed behaviour, we lose out on that implicit batch size. To rectify
  // this, we need to create a batch of the right size by repeating the prompts.
  if (seedBehaviour === 'PER_PROMPT' || model.base === 'chatgpt-4o') {
    return range(iterations).flatMap(() => prompts);
  }
  return prompts;
};
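As the comment explains, a seedless model would otherwise lose the batch size that the seed items normally imply, so the helper repeats the prompt list once per iteration. A standalone restatement of that branch with made-up values, assuming nothing beyond `lodash-es`:

```ts
import { range } from 'lodash-es';

// Hypothetical re-statement of the branch above, for illustration only.
const expandPrompts = (
  prompts: string[],
  iterations: number,
  base: string,
  seedBehaviour: 'PER_PROMPT' | 'PER_ITERATION'
): string[] =>
  seedBehaviour === 'PER_PROMPT' || base === 'chatgpt-4o' ? range(iterations).flatMap(() => prompts) : prompts;

// A seedless ChatGPT 4o model with the per-iteration behaviour still gets an explicit batch of 6 items:
console.log(expandPrompts(['a cat', 'a dog'], 3, 'chatgpt-4o', 'PER_ITERATION').length); // 6

// An SDXL model with the same settings keeps the original prompt list; the seed items supply the batch size:
console.log(expandPrompts(['a cat', 'a dog'], 3, 'sdxl', 'PER_ITERATION').length); // 2
```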
export const prepareLinearUIBatch = (
  state: RootState,
  g: Graph,
  prepend: boolean,
  noise: Invocation<NoiseNodes>,
  posCond: Invocation<ConditioningNodes>,
  origin: 'canvas' | 'workflows' | 'upscaling',
  destination: 'canvas' | 'gallery'
): EnqueueBatchArg => {
export const prepareLinearUIBatch = (arg: {
  state: RootState;
  g: Graph;
  prepend: boolean;
  seedFieldIdentifier?: FieldIdentifier;
  positivePromptFieldIdentifier: FieldIdentifier;
  origin: 'canvas' | 'workflows' | 'upscaling';
  destination: 'canvas' | 'gallery';
}): EnqueueBatchArg => {
  const { state, g, prepend, seedFieldIdentifier, positivePromptFieldIdentifier, origin, destination } = arg;
  const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.params;
  const { prompts, seedBehaviour } = state.dynamicPrompts;

  assert(model, 'No model found in state when preparing batch');

  const data: Batch['data'] = [];
  const firstBatchDatumList: components['schemas']['BatchDatum'][] = [];
  const secondBatchDatumList: components['schemas']['BatchDatum'][] = [];

  // add seeds first to ensure the output order groups the prompts
  if (seedBehaviour === 'PER_PROMPT') {
  if (seedFieldIdentifier && seedBehaviour === 'PER_PROMPT') {
    const seeds = generateSeeds({
      count: prompts.length * iterations,
      start: shouldRandomizeSeed ? undefined : seed,
      // Imagen3's support for seeded generation is iffy, so we are just not going to use it in linear UI generations.
      start:
        model.base === 'imagen3' ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) : shouldRandomizeSeed ? undefined : seed,
    });

    firstBatchDatumList.push({
      node_path: noise.id,
      field_name: 'seed',
      node_path: seedFieldIdentifier.nodeId,
      field_name: seedFieldIdentifier.fieldName,
      items: seeds,
    });

@@ -43,16 +68,18 @@ export const prepareLinearUIBatch = (
      field_name: 'seed',
      items: seeds,
    });
  } else {
  } else if (seedFieldIdentifier && seedBehaviour === 'PER_ITERATION') {
    // seedBehaviour = SeedBehaviour.PerRun
    const seeds = generateSeeds({
      count: iterations,
      start: shouldRandomizeSeed ? undefined : seed,
      // Imagen3's support for seeded generation is iffy, so we are just not going to use it in linear UI generations.
      start:
        model.base === 'imagen3' ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) : shouldRandomizeSeed ? undefined : seed,
    });

    secondBatchDatumList.push({
      node_path: noise.id,
      field_name: 'seed',
      node_path: seedFieldIdentifier.nodeId,
      field_name: seedFieldIdentifier.fieldName,
      items: seeds,
    });

@@ -66,12 +93,12 @@ export const prepareLinearUIBatch = (
    data.push(secondBatchDatumList);
  }

  const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
  const extendedPrompts = getExtendedPrompts({ seedBehaviour, iterations, prompts, model });

  // zipped batch of prompts
  firstBatchDatumList.push({
    node_path: posCond.id,
    field_name: 'prompt',
    node_path: positivePromptFieldIdentifier.nodeId,
    field_name: positivePromptFieldIdentifier.fieldName,
    items: extendedPrompts,
  });

@@ -83,9 +110,9 @@ export const prepareLinearUIBatch = (
    items: extendedPrompts,
  });

  if (shouldConcatPrompts && model?.base === 'sdxl') {
  if (shouldConcatPrompts && model.base === 'sdxl') {
    firstBatchDatumList.push({
      node_path: posCond.id,
      node_path: positivePromptFieldIdentifier.nodeId,
      field_name: 'style',
      items: extendedPrompts,
    });
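A sketch of the new calling convention follows. The import path and the node IDs are stand-ins for values produced elsewhere (the graph builders later in this diff return the field identifiers), so treat this as an illustration of the argument shape rather than the actual enqueue listener.

```ts
// Sketch only: module path and node IDs are assumptions for illustration.
import type { RootState } from 'app/store/store';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';

declare const state: RootState;
declare const g: Graph;

const batchArg = prepareLinearUIBatch({
  state,
  g,
  prepend: false,
  // Builders for seedless models (e.g. ChatGPT 4o) simply omit seedFieldIdentifier.
  seedFieldIdentifier: { nodeId: 'noise:abc123', fieldName: 'seed' },
  positivePromptFieldIdentifier: { nodeId: 'pos_cond:def456', fieldName: 'prompt' },
  origin: 'canvas',
  destination: 'gallery',
});
```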
@@ -1,9 +1,9 @@
import type { RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types';
import { assert } from 'tsafe';

@@ -12,7 +12,7 @@ import { getBoardField, selectPresetModifiedPrompts } from './graphBuilderUtils'

export const buildMultidiffusionUpscaleGraph = async (
  state: RootState
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> }> => {
): Promise<{ g: Graph; seedFieldIdentifier: FieldIdentifier; positivePromptFieldIdentifier: FieldIdentifier }> => {
  const {
    model,
    upscaleCfgScale: cfg_scale,
@@ -243,5 +243,9 @@ export const buildMultidiffusionUpscaleGraph = async (

  g.addEdge(controlNetCollector, 'collection', tiledMultidiffusion, 'control');

  return { g, noise, posCond };
  return {
    g,
    seedFieldIdentifier: { nodeId: noise.id, fieldName: 'seed' },
    positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
  };
};
@@ -0,0 +1,126 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isChatGPT4oAspectRatioID, isChatGPT4oReferenceImageConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { type ImageField, zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
  CANVAS_OUTPUT_PREFIX,
  getBoardField,
  selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';

const log = logger('system');

export const buildChatGPT4oGraph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
  const generationMode = await manager.compositor.getGenerationMode();

  if (generationMode !== 'txt2img' && generationMode !== 'img2img') {
    throw new UnsupportedGenerationModeError(t('toast.chatGPT4oIncompatibleGenerationMode'));
  }

  log.debug({ generationMode }, 'Building GPT Image graph');

  const model = selectMainModelConfig(state);

  const canvas = selectCanvasSlice(state);
  const canvasSettings = selectCanvasSettingsSlice(state);

  const { bbox } = canvas;
  const { positivePrompt } = selectPresetModifiedPrompts(state);

  assert(model, 'No model found in state');
  assert(model.base === 'chatgpt-4o', 'Model is not a ChatGPT 4o model');

  assert(isChatGPT4oAspectRatioID(bbox.aspectRatio.id), 'ChatGPT 4o does not support this aspect ratio');

  const validRefImages = canvas.referenceImages.entities
    .filter((entity) => entity.isEnabled)
    .filter((entity) => isChatGPT4oReferenceImageConfig(entity.ipAdapter))
    .filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0)
    .toReversed(); // sends them in order they are displayed in the list

  let reference_images: ImageField[] | undefined = undefined;

  if (validRefImages.length > 0) {
    reference_images = [];
    for (const entity of validRefImages) {
      assert(entity.ipAdapter.image, 'Image is required for reference image');
      reference_images.push({
        image_name: entity.ipAdapter.image.image_name,
      });
    }
  }

  const is_intermediate = canvasSettings.sendToCanvas;
  const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);

  if (generationMode === 'txt2img') {
    const g = new Graph(getPrefixedId('chatgpt_4o_txt2img_graph'));
    const gptImage = g.addNode({
      // @ts-expect-error: These nodes are not available in the OSS application
      type: 'chatgpt_4o_generate_image',
      id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
      model: zModelIdentifierField.parse(model),
      positive_prompt: positivePrompt,
      aspect_ratio: bbox.aspectRatio.id,
      reference_images,
      use_cache: false,
      is_intermediate,
      board,
    });
    g.upsertMetadata({
      positive_prompt: positivePrompt,
      model: Graph.getModelMetadataField(model),
      width: bbox.rect.width,
      height: bbox.rect.height,
    });
    return {
      g,
      positivePromptFieldIdentifier: { nodeId: gptImage.id, fieldName: 'positive_prompt' },
    };
  }

  if (generationMode === 'img2img') {
    const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
    const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
      is_intermediate: true,
      silent: true,
    });
    const g = new Graph(getPrefixedId('chatgpt_4o_img2img_graph'));
    const gptImage = g.addNode({
      // @ts-expect-error: These nodes are not available in the OSS application
      type: 'chatgpt_4o_edit_image',
      id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
      model: zModelIdentifierField.parse(model),
      positive_prompt: positivePrompt,
      aspect_ratio: bbox.aspectRatio.id,
      base_image: { image_name },
      reference_images,
      use_cache: false,
      is_intermediate,
      board,
    });
    g.upsertMetadata({
      positive_prompt: positivePrompt,
      model: Graph.getModelMetadataField(model),
      width: bbox.rect.width,
      height: bbox.rect.height,
    });
    return {
      g,
      positivePromptFieldIdentifier: { nodeId: gptImage.id, fieldName: 'positive_prompt' },
    };
  }

  assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for ChatGPT ');
};
@@ -19,7 +19,7 @@ import {
  getSizes,
  selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -27,10 +27,7 @@ import { assert } from 'tsafe';

const log = logger('system');

export const buildCogView4Graph = async (
  state: RootState,
  manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'cogview4_denoise'>; posCond: Invocation<'cogview4_text_encoder'> }> => {
export const buildCogView4Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
  const generationMode = await manager.compositor.getGenerationMode();
  log.debug({ generationMode }, 'Building CogView4 graph');

@@ -186,5 +183,9 @@ export const buildCogView4Graph = async (
  });

  g.setMetadataReceivingNode(canvasOutput);
  return { g, noise: denoise, posCond };
  return {
    g,
    seedFieldIdentifier: { nodeId: denoise.id, fieldName: 'seed' },
    positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
  };
};
@@ -22,7 +22,11 @@ import {
  getSizes,
  selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import {
  type GraphBuilderReturn,
  type ImageOutputNodes,
  UnsupportedGenerationModeError,
} from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Invocation } from 'services/api/types';
@@ -34,10 +38,7 @@ import { addIPAdapters } from './addIPAdapters';

const log = logger('system');

export const buildFLUXGraph = async (
  state: RootState,
  manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise' | 'flux_denoise'>; posCond: Invocation<'flux_text_encoder'> }> => {
export const buildFLUXGraph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
  const generationMode = await manager.compositor.getGenerationMode();
  log.debug({ generationMode }, 'Building FLUX graph');

@@ -83,7 +84,9 @@ export const buildFLUXGraph = async (
    //
    // The other asserts above are just for sanity & type check and should never be hit, so they do not have
    // translations.
    assert(generationMode === 'inpaint' || generationMode === 'outpaint', t('toast.fluxFillIncompatibleWithT2IAndI2I'));
    if (generationMode === 'txt2img' || generationMode === 'img2img') {
      throw new UnsupportedGenerationModeError(t('toast.fluxFillIncompatibleWithT2IAndI2I'));
    }

    // FLUX Fill wants much higher guidance values than normal FLUX - silently "fix" the value for the user.
    // TODO(psyche): Figure out a way to alert the user that this is happening - maybe return warnings from the graph
@@ -336,5 +339,9 @@ export const buildFLUXGraph = async (
  });

  g.setMetadataReceivingNode(canvasOutput);
  return { g, noise: denoise, posCond };
  return {
    g,
    seedFieldIdentifier: { nodeId: denoise.id, fieldName: 'seed' },
    positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
  };
};
@@ -0,0 +1,78 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isImagen3AspectRatioID } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
  CANVAS_OUTPUT_PREFIX,
  getBoardField,
  selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';

const log = logger('system');

export const buildImagen3Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
  const generationMode = await manager.compositor.getGenerationMode();

  if (generationMode !== 'txt2img') {
    throw new UnsupportedGenerationModeError(t('toast.imagen3IncompatibleGenerationMode'));
  }

  log.debug({ generationMode }, 'Building Imagen3 graph');

  const canvas = selectCanvasSlice(state);
  const canvasSettings = selectCanvasSettingsSlice(state);

  const { bbox } = canvas;
  const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
  const model = selectMainModelConfig(state);

  assert(model, 'No model found for Imagen3 graph');
  assert(model.base === 'imagen3', 'Imagen3 graph requires Imagen3 model');
  assert(isImagen3AspectRatioID(bbox.aspectRatio.id), 'Imagen3 does not support this aspect ratio');
  assert(positivePrompt.length > 0, 'Imagen3 requires positive prompt to have at least one character');

  const is_intermediate = canvasSettings.sendToCanvas;
  const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);

  if (generationMode === 'txt2img') {
    const g = new Graph(getPrefixedId('imagen3_txt2img_graph'));
    const imagen3 = g.addNode({
      // @ts-expect-error: These nodes are not available in the OSS application
      type: 'google_imagen3_generate_image',
      id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
      model: zModelIdentifierField.parse(model),
      positive_prompt: positivePrompt,
      negative_prompt: negativePrompt,
      aspect_ratio: bbox.aspectRatio.id,
      enhance_prompt: true,
      // When enhance_prompt is true, Imagen3 will return a new image every time, ignoring the seed.
      use_cache: false,
      is_intermediate,
      board,
    });
    g.upsertMetadata({
      positive_prompt: positivePrompt,
      negative_prompt: negativePrompt,
      width: bbox.rect.width,
      height: bbox.rect.height,
      model: Graph.getModelMetadataField(model),
    });
    return {
      g,
      seedFieldIdentifier: { nodeId: imagen3.id, fieldName: 'seed' },
      positivePromptFieldIdentifier: { nodeId: imagen3.id, fieldName: 'positive_prompt' },
    };
  }

  assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Imagen3');
};
@@ -23,7 +23,7 @@ import {
  getSizes,
  selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -33,10 +33,7 @@ import { addRegions } from './addRegions';

const log = logger('system');

export const buildSD1Graph = async (
  state: RootState,
  manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel'> }> => {
export const buildSD1Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
  const generationMode = await manager.compositor.getGenerationMode();
  log.debug({ generationMode }, 'Building SD1/SD2 graph');

@@ -316,5 +313,9 @@ export const buildSD1Graph = async (
  });

  g.setMetadataReceivingNode(canvasOutput);
  return { g, noise, posCond };
  return {
    g,
    seedFieldIdentifier: { nodeId: noise.id, fieldName: 'seed' },
    positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
  };
};
@@ -18,7 +18,7 @@ import {
  getSizes,
  selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -26,10 +26,7 @@ import { assert } from 'tsafe';

const log = logger('system');

export const buildSD3Graph = async (
  state: RootState,
  manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'sd3_denoise'>; posCond: Invocation<'sd3_text_encoder'> }> => {
export const buildSD3Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
  const generationMode = await manager.compositor.getGenerationMode();
  log.debug({ generationMode }, 'Building SD3 graph');

@@ -211,5 +208,9 @@ export const buildSD3Graph = async (
  });

  g.setMetadataReceivingNode(canvasOutput);
  return { g, noise: denoise, posCond };
  return {
    g,
    seedFieldIdentifier: { nodeId: denoise.id, fieldName: 'seed' },
    positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
  };
};
@@ -23,7 +23,7 @@ import {
  getSizes,
  selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -33,10 +33,7 @@ import { addRegions } from './addRegions';

const log = logger('system');

export const buildSDXLGraph = async (
  state: RootState,
  manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'sdxl_compel_prompt'> }> => {
export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
  const generationMode = await manager.compositor.getGenerationMode();
  log.debug({ generationMode }, 'Building SDXL graph');

@@ -323,5 +320,9 @@ export const buildSDXLGraph = async (
  });

  g.setMetadataReceivingNode(canvasOutput);
  return { g, noise, posCond };
  return {
    g,
    seedFieldIdentifier: { nodeId: noise.id, fieldName: 'seed' },
    positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
  };
};
@@ -1,3 +1,6 @@
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';

export type ImageOutputNodes =
  | 'l2i'
  | 'img_nsfw'
@@ -23,11 +26,16 @@ export type MainModelLoaderNodes =
  | 'cogview4_model_loader';

export type VaeSourceNodes = 'seamless' | 'vae_loader';
export type NoiseNodes = 'noise' | 'flux_denoise' | 'sd3_denoise' | 'cogview4_denoise';

export type ConditioningNodes =
  | 'compel'
  | 'sdxl_compel_prompt'
  | 'flux_text_encoder'
  | 'sd3_text_encoder'
  | 'cogview4_text_encoder';
export type GraphBuilderReturn = {
  g: Graph;
  seedFieldIdentifier?: FieldIdentifier;
  positivePromptFieldIdentifier: FieldIdentifier;
};

export class UnsupportedGenerationModeError extends Error {
  constructor(message: string) {
    super(message);
    this.name = this.constructor.name;
  }
}
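One way the shared return shape and error class might be consumed is sketched below; this is a hedged illustration rather than the actual enqueue listener, which also selects the builder by model base and feeds the identifiers into prepareLinearUIBatch.

```ts
// Sketch only: a hypothetical wrapper around any graph builder.
import type { GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';

const tryBuildGraph = async (build: () => Promise<GraphBuilderReturn>): Promise<GraphBuilderReturn | null> => {
  try {
    // seedFieldIdentifier may be undefined for seedless builders (Imagen3 with prompt enhancement, ChatGPT 4o).
    return await build();
  } catch (error) {
    if (error instanceof UnsupportedGenerationModeError) {
      // Surface the translated message to the user instead of failing silently.
      console.warn(error.message);
      return null;
    }
    throw error;
  }
};
```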
@@ -33,6 +33,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
  ControlLoRAModelField: undefined,
  SigLipModelField: undefined,
  FluxReduxModelField: undefined,
  Imagen3ModelField: undefined,
  ChatGPT4oModelField: undefined,
  FloatGeneratorField: undefined,
  IntegerGeneratorField: undefined,
  StringGeneratorField: undefined,
Some files were not shown because too many files have changed in this diff.