Compare commits

..

3 Commits

Author SHA1 Message Date
Brandon Rising
bbb48c5475 Gitignore aider 2024-05-28 15:42:34 -04:00
Brandon Rising
e68c49167a aider: Fixed CPU memory issues by deleting large objects and forcing garbage collection.
# Aider chat conversation:

USER: `invokeai/backend/model_manager/convert_ckpt_to_diffusers.py` Im having issues with cpu memory, not gpu memory

ASSISTANT: To address CPU memory issues, we can make the following changes:

1. Use `del` to explicitly delete large objects when they are no longer needed.
2. Use `gc.collect()` to force garbage collection after deleting large objects.

Here are the *SEARCH/REPLACE* blocks:

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    vae = AutoencoderKL(**vae_config)
    with torch.no_grad():
        vae.load_state_dict(converted_vae_checkpoint)
        vae.to(precision)
        torch.cuda.empty_cache()
=======
    vae = AutoencoderKL(**vae_config)
    with torch.no_grad():
        vae.load_state_dict(converted_vae_checkpoint)
        del converted_vae_checkpoint  # Free memory
        import gc
        gc.collect()
        vae.to(precision)
>>>>>>> REPLACE
```

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
    with torch.no_grad():
        pipe = pipe.to(precision)
        torch.cuda.empty_cache()
=======
    pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
    with torch.no_grad():
        del kwargs  # Free memory
        import gc
        gc.collect()
        pipe = pipe.to(precision)
>>>>>>> REPLACE
```

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
    with torch.no_grad():
        pipe = pipe.to(precision)
        torch.cuda.empty_cache()
=======
    pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
    with torch.no_grad():
        del kwargs  # Free memory
        import gc
        gc.collect()
        pipe = pipe.to(precision)
>>>>>>> REPLACE
```
2024-05-28 15:05:36 -04:00
Brandon Rising
46950a9bd0 aider: Improved memory usage efficiency by avoiding intermediate value storage and clearing the cache.
# Aider chat conversation:

USER: `invokeai/backend/model_manager/convert_ckpt_to_diffusers.py` Could you improve the memory usage efficiency of this code?

ASSISTANT: To improve the memory usage efficiency of the code, we can make the following changes:

1. Use `with torch.no_grad()` to avoid storing intermediate values for backpropagation.
2. Move the model to the desired device only when necessary.
3. Clear the cache after moving the model to the desired device.

Here are the *SEARCH/REPLACE* blocks:

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    vae.load_state_dict(converted_vae_checkpoint)
    vae.to(precision)
=======
    with torch.no_grad():
        vae.load_state_dict(converted_vae_checkpoint)
        vae.to(precision)
        torch.cuda.empty_cache()
>>>>>>> REPLACE
```

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    pipe = pipe.to(precision)
=======
    with torch.no_grad():
        pipe = pipe.to(precision)
        torch.cuda.empty_cache()
>>>>>>> REPLACE
```

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    pipe = pipe.to(precision)
=======
    with torch.no_grad():
        pipe = pipe.to(precision)
        torch.cuda.empty_cache()
>>>>>>> REPLACE
```
2024-05-28 15:03:43 -04:00
97 changed files with 3021 additions and 5647 deletions

1
.gitignore vendored
View File

@@ -188,3 +188,4 @@ installer/install.sh
installer/update.bat
installer/update.sh
installer/InvokeAI-Installer/
.aider*

View File

@@ -18,7 +18,6 @@ help:
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
@@ -71,6 +70,3 @@ installer-zip:
tag-release:
cd installer && ./tag_release.sh
# Generate the OpenAPI Schema for the app
openapi:
python scripts/generate_openapi_schema.py

View File

@@ -154,18 +154,6 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
Check the [configuration docs] for more detail about the settings and how to specify them.
## `ModuleNotFoundError: No module named 'controlnet_aux'`
`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
- Run the Invoke launcher
- Choose the developer console option
- Run this command: `pip cache remove controlnet_aux`
- Close the terminal window
- Download and run the [installer](https://github.com/invoke-ai/InvokeAI/releases/latest), selecting your current install location
## Out of Memory Issues
The models are large, VRAM is expensive, and you may find yourself

View File

@@ -3,7 +3,9 @@ import logging
import mimetypes
import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any
import torch
import uvicorn
@@ -11,9 +13,11 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
@@ -21,8 +25,10 @@ from torch.backends.mps import is_available as is_mps_available
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger
@@ -39,6 +45,11 @@ from .api.routers import (
workflows,
)
from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
app_config = get_config()
@@ -108,7 +119,84 @@ app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
app.openapi = get_openapi_func(app)
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = {}
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
output_schemas = models_json_schema(
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
)
for schema_key, output_schema in output_schemas[1]["$defs"].items():
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
# Some models don't end up in the schemas as standalone definitions
additional_schemas = models_json_schema(
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
(ModelIdentifierField, "serialization"),
(ProgressImage, "serialization"),
],
ref_template="#/components/schemas/{model}",
)
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": {},
"required": [],
}
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation"
# Add all event schemas
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
if "$defs" in json_schema:
for schema_key, schema in json_schema["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema
del json_schema["$defs"]
openapi_schema["components"]["schemas"][event.__name__] = json_schema
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
@app.get("/docs", include_in_schema=False)

View File

@@ -98,13 +98,11 @@ class BaseInvocationOutput(BaseModel):
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls._typeadapter_needs_update = True
@classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
@@ -114,12 +112,11 @@ class BaseInvocationOutput(BaseModel):
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocationOutput = TypeAliasType(
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
if not cls._typeadapter:
InvocationOutputsUnion = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
cls._typeadapter_needs_update = False
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
return cls._typeadapter
@classmethod
@@ -128,13 +125,12 @@ class BaseInvocationOutput(BaseModel):
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
# Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "output"
schema["required"].extend(["type"])
@classmethod
@@ -171,7 +167,6 @@ class BaseInvocation(ABC, BaseModel):
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def get_type(cls) -> str:
@@ -182,17 +177,15 @@ class BaseInvocation(ABC, BaseModel):
def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls._typeadapter_needs_update = True
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocation = TypeAliasType(
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
if not cls._typeadapter:
InvocationsUnion = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocation)
cls._typeadapter_needs_update = False
cls._typeadapter = TypeAdapter(InvocationsUnion)
return cls._typeadapter
@classmethod
@@ -228,7 +221,7 @@ class BaseInvocation(ABC, BaseModel):
return signature(cls.invoke).return_annotation
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None:
@@ -244,7 +237,6 @@ class BaseInvocation(ABC, BaseModel):
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "invocation"
schema["required"].extend(["type", "id"])
@abstractmethod
@@ -318,7 +310,7 @@ class BaseInvocation(ABC, BaseModel):
protected_namespaces=(),
validate_assignment=True,
json_schema_extra=json_schema_extra,
json_schema_serialization_defaults_required=False,
json_schema_serialization_defaults_required=True,
coerce_numbers_to_str=True,
)

View File

@@ -1,98 +0,0 @@
from typing import Any, Union
import numpy as np
import numpy.typing as npt
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
@invocation(
"lblend",
title="Blend Latents",
tags=["latents", "blend"],
category="latents",
version="1.0.3",
)
class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size."""
latents_a: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
latents_b: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents_a = context.tensors.load(self.latents_a.latents_name)
latents_b = context.tensors.load(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
device = TorchDevice.choose_torch_device()
def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
v0: Union[torch.Tensor, npt.NDArray[Any]],
v1: Union[torch.Tensor, npt.NDArray[Any]],
DOT_THRESHOLD: float = 0.9995,
) -> Union[torch.Tensor, npt.NDArray[Any]]:
"""
Spherical linear interpolation
Args:
t (float/np.ndarray): Float value between 0.0 and 1.0
v0 (np.ndarray): Starting vector
v1 (np.ndarray): Final vector
DOT_THRESHOLD (float): Threshold for considering the two vectors as
colineal. Not recommended to alter this.
Returns:
v2 (np.ndarray): Interpolation vector between v0 and v1
"""
inputs_are_torch = False
if not isinstance(v0, np.ndarray):
inputs_are_torch = True
v0 = v0.detach().cpu().numpy()
if not isinstance(v1, np.ndarray):
inputs_are_torch = True
v1 = v1.detach().cpu().numpy()
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > DOT_THRESHOLD:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
return v2_torch
else:
assert isinstance(v2, np.ndarray)
return v2
# blend
bl = slerp(self.alpha, latents_a, latents_b)
assert isinstance(bl, torch.Tensor)
blended_latents: torch.Tensor = bl # for type checking convenience
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)

View File

@@ -1,7 +1,6 @@
from typing import Literal
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
LATENT_SCALE_FACTOR = 8
"""
@@ -16,5 +15,3 @@ SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()

View File

@@ -1,80 +0,0 @@
from typing import Optional
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import DenoiseMaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation(
"create_denoise_mask",
title="Create Denoise Mask",
tags=["mask", "denoise"],
category="latents",
version="1.0.2",
)
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=4,
)
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
# if shape is not None:
# mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
return mask_tensor
@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None:
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
else:
image_tensor = None
mask = self.prep_mask_tensor(
context.images.get_pil(self.mask.image_name),
)
if image_tensor is not None:
vae_info = context.models.load(self.vae.vae)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = context.tensors.save(tensor=masked_latents)
else:
masked_latents_name = None
mask_name = context.tensors.save(tensor=mask)
return DenoiseMaskOutput.build(
mask_name=mask_name,
masked_latents_name=masked_latents_name,
gradient=False,
)

View File

@@ -1,138 +0,0 @@
from typing import Literal, Optional
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageFilter
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
ImageField,
Input,
InputField,
OutputField,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.1.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
)
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
)
image: Optional[ImageField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] Image",
ui_order=6,
)
unet: Optional[UNetField] = InputField(
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
default=None,
input=Input.Connection,
title="[OPTIONAL] UNet",
ui_order=5,
)
vae: Optional[VAEField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] VAE",
input=Input.Connection,
ui_order=7,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=9,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.edge_radius > 0:
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
else:
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
masked_latents_name = None
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = blur_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)
masked_latents_name = context.tensors.save(tensor=masked_latents)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)

View File

@@ -1,61 +0,0 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
# The Crop Latents node was copied from @skunkworxdark's implementation here:
# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80
@invocation(
"crop_latents",
title="Crop Latents",
tags=["latents", "crop"],
category="latents",
version="1.0.2",
)
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
# Currently, if the class names conflict then 'GET /openapi.json' fails.
class CropLatentsCoreInvocation(BaseInvocation):
"""Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be
divisible by the latent scale factor of 8.
"""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
x: int = InputField(
ge=0,
multiple_of=LATENT_SCALE_FACTOR,
description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
y: int = InputField(
ge=0,
multiple_of=LATENT_SCALE_FACTOR,
description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
width: int = InputField(
ge=1,
multiple_of=LATENT_SCALE_FACTOR,
description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
height: int = InputField(
ge=1,
multiple_of=LATENT_SCALE_FACTOR,
description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
x1 = self.x // LATENT_SCALE_FACTOR
y1 = self.y // LATENT_SCALE_FACTOR
x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
y2 = y1 + (self.height // LATENT_SCALE_FACTOR)
cropped_latents = latents[..., y1:y2, x1:x2]
name = context.tensors.save(tensor=cropped_latents)
return LatentsOutput.build(latents_name=name, latents=cropped_latents)

View File

@@ -1,65 +0,0 @@
import math
from typing import Tuple
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
from invokeai.app.invocations.model import UNetField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
@invocation_output("ideal_size_output")
class IdealSizeOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
width: int = OutputField(description="The ideal width of the image (in pixels)")
height: int = OutputField(description="The ideal height of the image (in pixels)")
@invocation(
"ideal_size",
title="Ideal Size",
tags=["latents", "math", "ideal_size"],
version="1.0.3",
)
class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication"""
width: int = InputField(default=1024, description="Final image width")
height: int = InputField(default=576, description="Final image height")
unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
multiplier: float = InputField(
default=1.0,
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "
"initial generation artifacts if too large)",
)
def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
return tuple((x - x % multiple_of) for x in args)
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.models.get_config(self.unet.unet.key)
aspect = self.width / self.height
dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2:
dimension = 768
elif unet_config.base == BaseModelType.StableDiffusionXL:
dimension = 1024
dimension = dimension * self.multiplier
min_dimension = math.floor(dimension * 0.5)
model_area = dimension * dimension # hardcoded for now since all models are trained on square images
if aspect > 1.0:
init_height = max(min_dimension, math.sqrt(model_area / aspect))
init_width = init_height * aspect
else:
init_width = max(min_dimension, math.sqrt(model_area * aspect))
init_height = init_width / aspect
scaled_width, scaled_height = self.trim_to_multiple_of(
math.floor(init_width),
math.floor(init_height),
)
return IdealSizeOutput(width=scaled_width, height=scaled_height)

View File

@@ -1,125 +0,0 @@
from functools import singledispatchmethod
import einops
import torch
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation(
"i2l",
title="Image to Latents",
tags=["latents", "image", "vae", "i2l"],
category="latents",
version="1.0.2",
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
image: ImageField = InputField(
description="The image to encode",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, torch.nn.Module)
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(orig_dtype)
vae.decoder.conv_in.to(orig_dtype)
vae.decoder.mid_block.to(orig_dtype)
# else:
# latents = latents.float()
else:
vae.to(dtype=torch.float16)
# latents = latents.half()
if tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype)
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
@singledispatchmethod
@staticmethod
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents: torch.Tensor = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
return latents
@_encode_to_tensor.register
@staticmethod
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
latents: torch.FloatTensor = vae.encode(image_tensor).latents
return latents

File diff suppressed because it is too large Load Diff

View File

@@ -1,127 +0,0 @@
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.util.devices import TorchDevice
@invocation(
"l2i",
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.2.2",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_decode(
context: InvocationContext,
vae_info: LoadedModel,
seamless_axes: list[str],
latents: torch.Tensor,
use_fp32: bool,
use_tiling: bool,
) -> Image.Image:
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if use_fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(latents.dtype)
vae.decoder.conv_in.to(latents.dtype)
vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()
else:
vae.to(dtype=torch.float16)
latents = latents.half()
if use_tiling or context.config.get().force_tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode():
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
TorchDevice.empty_cache()
return image
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
image = self.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=latents,
use_fp32=self.fp32,
use_tiling=self.tiled,
)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

View File

@@ -1,103 +0,0 @@
from typing import Literal
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
)
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
@invocation(
"lresize",
title="Resize Latents",
tags=["latents", "resize"],
category="latents",
version="1.0.2",
)
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
width: int = InputField(
ge=64,
multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width,
)
height: int = InputField(
ge=64,
multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width,
)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
device = TorchDevice.choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@invocation(
"lscale",
title="Scale Latents",
tags=["latents", "resize"],
category="latents",
version="1.0.2",
)
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
device = TorchDevice.choose_torch_device()
# resizing
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
scale_factor=self.scale_factor,
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)

View File

@@ -1,34 +0,0 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
FieldDescriptions,
InputField,
OutputField,
UIType,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput):
scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
@invocation(
"scheduler",
title="Scheduler",
tags=["scheduler"],
category="latents",
version="1.0.0",
)
class SchedulerInvocation(BaseInvocation):
"""Selects a scheduler."""
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
def invoke(self, context: InvocationContext) -> SchedulerOutput:
return SchedulerOutput(scheduler=self.scheduler)

View File

@@ -1,384 +0,0 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
import numpy as np
import numpy.typing as npt
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
ImageField,
Input,
InputField,
UIType,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.latent import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
from invokeai.app.invocations.noise import get_noise
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@invocation(
"tiled_stable_diffusion_refine",
title="Tiled Stable Diffusion Refine",
tags=["upscale", "denoise"],
category="latents",
version="1.0.0",
)
class TiledStableDiffusionRefineInvocation(BaseInvocation):
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
"""
image: ImageField = InputField(description="Image to be refined.")
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
# TODO(ryand): Add multiple-of validation.
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
tile_overlap: int = InputField(
default=16,
gt=0,
description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField(
default=0.65,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
vae_fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
)
# HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we akwardly
# don't want to use the image field. Figure out how best to handle this.
# TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
# much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
# CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
# etc.
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: float = InputField(default=0.6)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
"""
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
if coord % LATENT_SCALE_FACTOR != 0:
raise ValueError(
f"The tile coordinates must all be divisible by the latent scale factor"
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
)
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
top = image_tile.coords.top // LATENT_SCALE_FACTOR
left = image_tile.coords.left // LATENT_SCALE_FACTOR
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
right = image_tile.coords.right // LATENT_SCALE_FACTOR
return latents[..., top:bottom, left:right]
def run_controlnet(
self,
image: Image.Image,
controlnet_model: ControlNetModel,
weight: float,
do_classifier_free_guidance: bool,
width: int,
height: int,
device: torch.device,
dtype: torch.dtype,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
) -> ControlNetData:
control_image = prepare_control_image(
image=image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
device=device,
dtype=dtype,
control_mode=control_mode,
resize_mode=resize_mode,
)
return ControlNetData(
model=controlnet_model,
image_tensor=control_image,
weight=weight,
begin_step_percent=0.0,
end_step_percent=1.0,
control_mode=control_mode,
# Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
# ControlNetData in case needed in the future.
resize_mode=resize_mode,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# TODO(ryand): Expose the seed parameter.
seed = 0
# Load the input image.
input_image = context.images.get_pil(self.image.image_name)
# Calculate the tile locations to cover the image.
# We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
# facilitates conversions between image space and latent space.
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
tiles = calc_tiles_with_overlap(
image_height=input_image.height,
image_width=input_image.width,
tile_height=self.tile_height,
tile_width=self.tile_width,
overlap=self.tile_overlap,
)
# Convert the input image to a torch.Tensor.
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
# Validate our assumptions about the shape of input_image_torch.
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
assert input_image_torch.shape[:2] == (1, 3)
# Split the input image into tiles in torch.Tensor format.
image_tiles_torch: list[torch.Tensor] = []
for tile in tiles:
image_tile = input_image_torch[
:,
:,
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
]
image_tiles_torch.append(image_tile)
# Split the input image into tiles in numpy format.
# TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
# with torch.Tensor tiles.
input_image_np = np.array(input_image)
image_tiles_np: list[npt.NDArray[np.uint8]] = []
for tile in tiles:
image_tile_np = input_image_np[
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
:,
]
image_tiles_np.append(image_tile_np)
# VAE-encode each image tile independently.
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
# about for decoding?
vae_info = context.models.load(self.vae.vae)
latent_tiles: list[torch.Tensor] = []
for image_tile_torch in image_tiles_torch:
latent_tiles.append(
ImageToLatentsInvocation.vae_encode(
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
)
)
# Generate noise with dimensions corresponding to the full image in latent space.
# It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
# generating for each tile independently. This ensures that overlapping regions between tiles use the same
# noise.
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
global_noise = get_noise(
width=input_image_torch.shape[3],
height=input_image_torch.shape[2],
device=TorchDevice.choose_torch_device(),
seed=seed,
downsampling_factor=LATENT_SCALE_FACTOR,
use_cpu=True,
)
# Crop the global noise into tiles.
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
refined_latent_tiles: list[torch.Tensor] = []
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
# Assume that all tiles have the same shape.
_, _, latent_height, latent_width = latent_tiles[0].shape
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
# Load the ControlNet model.
# TODO(ryand): Support multiple ControlNet models.
controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
assert isinstance(controlnet_model, ControlNetModel)
# Denoise (i.e. "refine") each tile independently.
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
assert latent_tile.shape == noise_tile.shape
# Prepare a PIL Image for ControlNet processing.
# TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
# the tiles. Ideally, the ControlNet code should be able to work with Tensors.
image_tile_pil = Image.fromarray(image_tile_np)
# Run the ControlNet on the image tile.
height, width, _ = image_tile_np.shape
# The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
# validate this assumption here.
assert height % LATENT_SCALE_FACTOR == 0
assert width % LATENT_SCALE_FACTOR == 0
controlnet_data = self.run_controlnet(
image=image_tile_pil,
controlnet_model=controlnet_model,
weight=self.control_weight,
do_classifier_free_guidance=True,
width=width,
height=height,
device=controlnet_model.device,
dtype=controlnet_model.dtype,
control_mode="balanced",
resize_mode="just_resize_simple",
)
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = (
DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
)
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
refined_latent_tile = pipeline.latents_from_embeddings(
latents=latent_tile,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise_tile,
seed=seed,
mask=None,
masked_latents=None,
gradient_mask=None,
num_inference_steps=num_inference_steps,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=[controlnet_data],
ip_adapter_data=None,
t2i_adapter_data=None,
callback=lambda x: None,
)
refined_latent_tiles.append(refined_latent_tile)
# VAE-decode each refined latent tile independently.
refined_image_tiles: list[Image.Image] = []
for refined_latent_tile in refined_latent_tiles:
refined_image_tile = LatentsToImageInvocation.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=refined_latent_tile,
use_fp32=self.vae_fp32,
use_tiling=False,
)
refined_image_tiles.append(refined_image_tile)
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
TorchDevice.empty_cache()
# Merge the refined image tiles back into a single image.
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
# TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
merge_tiles_with_linear_blending(
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
)
# Save the refined image and return its reference.
merged_image_pil = Image.fromarray(merged_image_np)
image_dto = context.images.save(image=merged_image_pil)
return ImageOutput.build(image_dto)

View File

@@ -3,8 +3,9 @@ from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, P
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
@@ -13,7 +14,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -98,9 +98,17 @@ class InvocationEventBase(QueueItemEventBase):
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@field_validator("invocation", mode="plain")
@classmethod
def validate_invocation(cls, v: Any):
"""Validates the invocation using the dynamic type adapter."""
invocation = BaseInvocation.get_typeadapter().validate_python(v)
return invocation
@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
@@ -109,7 +117,7 @@ class InvocationStartedEvent(InvocationEventBase):
__event_name__ = "invocation_started"
@classmethod
def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
@@ -136,7 +144,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
invocation: BaseInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent":
@@ -174,11 +182,19 @@ class InvocationCompleteEvent(InvocationEventBase):
__event_name__ = "invocation_complete"
result: AnyInvocationOutput = Field(description="The result of the invocation")
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
@field_validator("result", mode="plain")
@classmethod
def validate_results(cls, v: Any):
"""Validates the invocation result using the dynamic type adapter."""
result = BaseInvocationOutput.get_typeadapter().validate_python(v)
return result
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
@@ -207,7 +223,7 @@ class InvocationErrorEvent(InvocationEventBase):
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
invocation: BaseInvocation,
error_type: str,
error_message: str,
error_traceback: str,

View File

@@ -2,19 +2,18 @@
import copy
import itertools
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import (
BaseModel,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
field_validator,
)
from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
from pydantic_core import CoreSchema
# Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import * # noqa: F401 F403
@@ -278,58 +277,73 @@ class CollectInvocation(BaseInvocation):
return CollectInvocationOutput(collection=copy.copy(self.collection))
class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocation.get_invocations()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation_output)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid_string)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph",
default_factory=list,
)
@field_validator("nodes", mode="plain")
@classmethod
def validate_nodes(cls, v: dict[str, Any]):
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
# Invocations register themselves as their python modules are executed. The union of all invocations is
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
#
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
# invocations will cause a graph to fail if they are used.
#
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
#
# This same pattern is used in `GraphExecutionState`.
nodes: dict[str, BaseInvocation] = {}
typeadapter = BaseInvocation.get_typeadapter()
for node_id, node in v.items():
nodes[node_id] = typeadapter.validate_python(node)
return nodes
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
# the generated schema as options for the `nodes` field.
#
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
# expected.
#
# You might be tempted to do something like this:
#
# ```py
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
# delattr(cloned_model, "validate_nodes")
# cloned_model.model_rebuild(force=True)
# json_schema = handler(cloned_model.__pydantic_core_schema__)
# ```
#
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
#
# This same pattern is used in `GraphExecutionState`.
class Graph(BaseModel):
id: Optional[str] = Field(default=None, description="The id of this graph")
nodes: dict[
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
] = Field(description="The nodes in this graph")
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
json_schema = handler(Graph.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
@@ -760,7 +774,7 @@ class GraphExecutionState(BaseModel):
)
# The results of executed nodes
results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@@ -777,12 +791,52 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
@field_validator("results", mode="plain")
@classmethod
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
results: dict[str, BaseInvocationOutput] = {}
typeadapter = BaseInvocationOutput.get_typeadapter()
for result_id, result in v.items():
results[result_id] = typeadapter.validate_python(result)
return results
@field_validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""
v.validate_self()
return v
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state")
graph: Graph = Field(description="The graph being executed")
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
executed: set[str] = Field(description="The set of node ids that have been executed")
executed_history: list[str] = Field(
description="The list of node ids that have been executed, in order of execution"
)
results: dict[
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
] = Field(description="The results of node executions")
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
prepared_source_mapping: dict[str, str] = Field(
description="The map of prepared nodes to original graph nodes"
)
source_prepared_mapping: dict[str, set[str]] = Field(
description="The map of original graph nodes to prepared nodes"
)
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def next(self) -> Optional[BaseInvocation]:
"""Gets the next node ready to execute."""

View File

@@ -289,7 +289,7 @@ def prepare_control_image(
width: int,
height: int,
num_channels: int = 3,
device: str | torch.device = "cuda",
device: str = "cuda",
dtype: torch.dtype = torch.float16,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
@@ -304,7 +304,7 @@ def prepare_control_image(
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
device (str | torch.Device, optional): The target device for the output image. Defaults to "cuda".
device (str, optional): The target device for the output image. Defaults to "cuda".
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
Defaults to True.

View File

@@ -1,116 +0,0 @@
from typing import Any, Callable, Optional
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
"""Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
component_schema."""
defs = component_schema.pop("$defs", {})
for schema_key, json_schema in defs.items():
if schema_key in openapi_schema["components"]["schemas"]:
continue
openapi_schema["components"]["schemas"][schema_key] = json_schema
def get_openapi_func(
app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
) -> Callable[[], dict[str, Any]]:
"""Gets the OpenAPI schema generator function.
Args:
app (FastAPI): The FastAPI app to generate the schema for.
post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
generated schema before returning it. Defaults to None.
Returns:
Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
matches FastAPI's default schema generation caching.
"""
def openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# We'll create a map of invocation type to output schema to make some types simpler on the client.
invocation_output_map_properties: dict[str, Any] = {}
invocation_output_map_required: list[str] = []
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in BaseInvocationOutput.get_outputs():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
# property, so we'll just do it all manually.
for invocation in BaseInvocation.get_invocations():
json_schema = invocation.model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)
move_defs_to_top_level(openapi_schema, json_schema)
output_title = invocation.get_output_annotation().__name__
outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
json_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"][invocation.__name__] = json_schema
# Add this invocation and its output to the output map
invocation_type = invocation.get_type()
invocation_output_map_properties[invocation_type] = json_schema["output"]
invocation_output_map_required.append(invocation_type)
# Add the output map to the schema
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": invocation_output_map_properties,
"required": invocation_output_map_required,
}
# Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
# We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
# a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
# `models_json_schema` seems to work fine.
additional_models = [
*EventBase.get_events(),
UIConfigBase,
InputFieldJSONSchemaExtra,
OutputFieldJSONSchemaExtra,
ModelIdentifierField,
ProgressImage,
]
additional_schemas = models_json_schema(
[(m, "serialization") for m in additional_models],
ref_template="#/components/schemas/{model}",
)
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
move_defs_to_top_level(openapi_schema, additional_schemas[1])
if post_transform is not None:
openapi_schema = post_transform(openapi_schema)
openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items()))
app.openapi_schema = openapi_schema
return app.openapi_schema
return openapi

View File

@@ -30,8 +30,12 @@ def convert_ldm_vae_to_diffusers(
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
vae.to(precision)
with torch.no_grad():
vae.load_state_dict(converted_vae_checkpoint)
del converted_vae_checkpoint # Free memory
import gc
gc.collect()
vae.to(precision)
if dump_path:
vae.save_pretrained(dump_path, safe_serialization=True)
@@ -52,7 +56,11 @@ def convert_ckpt_to_diffusers(
model to be written.
"""
pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
pipe = pipe.to(precision)
with torch.no_grad():
del kwargs # Free memory
import gc
gc.collect()
pipe = pipe.to(precision)
# TO DO: save correct repo variant
if dump_path:
@@ -75,7 +83,11 @@ def convert_controlnet_to_diffusers(
model to be written.
"""
pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
pipe = pipe.to(precision)
with torch.no_grad():
del kwargs # Free memory
import gc
gc.collect()
pipe = pipe.to(precision)
# TO DO: save correct repo variant
if dump_path:

View File

@@ -60,5 +60,5 @@ class ModelLocker(ModelLockerBase):
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(0)
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.print_cuda_stats()

View File

@@ -10,7 +10,7 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.util.silence_warnings import SilenceWarnings
from invokeai.backend.util.util import SilenceWarnings
from .config import (
AnyModelConfig,

View File

@@ -11,6 +11,7 @@ import psutil
import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -25,7 +26,6 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@dataclass

View File

@@ -1,36 +1,29 @@
import warnings
from contextlib import ContextDecorator
"""Context class to silence transformers and diffusers warnings."""
from diffusers.utils import logging as diffusers_logging
import warnings
from typing import Any
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
# Inherit from ContextDecorator to allow using SilenceWarnings as both a context manager and a decorator.
class SilenceWarnings(ContextDecorator):
"""A context manager that disables warnings from transformers & diffusers modules while active.
class SilenceWarnings(object):
"""Use in context to temporarily turn off warnings from transformers & diffusers modules.
As context manager:
```
with SilenceWarnings():
# do something
```
As decorator:
```
@SilenceWarnings()
def some_function():
# do something
```
"""
def __init__(self) -> None:
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self) -> None:
self._transformers_verbosity = transformers_logging.get_verbosity()
self._diffusers_verbosity = diffusers_logging.get_verbosity()
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, *args) -> None:
transformers_logging.set_verbosity(self._transformers_verbosity)
diffusers_logging.set_verbosity(self._diffusers_verbosity)
def __exit__(self, *args: Any) -> None:
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")

View File

@@ -1,9 +1,12 @@
import base64
import io
import os
import warnings
from pathlib import Path
from diffusers import logging as diffusers_logging
from PIL import Image
from transformers import logging as transformers_logging
# actual size of a gig
GIG = 1073741824
@@ -48,3 +51,21 @@ class Chdir(object):
def __exit__(self, *args):
os.chdir(self.original)
class SilenceWarnings(object):
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
def __enter__(self):
"""Set verbosity to error."""
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, type, value, traceback):
"""Restore logger verbosity to state before context was entered."""
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")

View File

@@ -1021,8 +1021,7 @@
"float": "Kommazahlen",
"enum": "Aufzählung",
"fullyContainNodes": "Vollständig ausgewählte Nodes auswählen",
"editMode": "Im Workflow-Editor bearbeiten",
"resetToDefaultValue": "Auf Standardwert zurücksetzen"
"editMode": "Im Workflow-Editor bearbeiten"
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",

View File

@@ -148,8 +148,6 @@
"viewingDesc": "Review images in a large gallery view",
"editing": "Editing",
"editingDesc": "Edit on the Control Layers canvas",
"comparing": "Comparing",
"comparingDesc": "Comparing two images",
"enabled": "Enabled",
"disabled": "Disabled"
},
@@ -377,23 +375,7 @@
"bulkDownloadRequestFailed": "Problem Preparing Download",
"bulkDownloadFailed": "Download Failed",
"problemDeletingImages": "Problem Deleting Images",
"problemDeletingImagesDesc": "One or more images could not be deleted",
"viewerImage": "Viewer Image",
"compareImage": "Compare Image",
"openInViewer": "Open in Viewer",
"selectForCompare": "Select for Compare",
"selectAnImageToCompare": "Select an Image to Compare",
"slider": "Slider",
"sideBySide": "Side-by-Side",
"hover": "Hover",
"swapImages": "Swap Images",
"compareOptions": "Comparison Options",
"stretchToFit": "Stretch to Fit",
"exitCompare": "Exit Compare",
"compareHelp1": "Hold <Kbd>Alt</Kbd> while clicking a gallery image or using the arrow keys to change the compare image.",
"compareHelp2": "Press <Kbd>M</Kbd> to cycle through comparison modes.",
"compareHelp3": "Press <Kbd>C</Kbd> to swap the compared images.",
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit."
"problemDeletingImagesDesc": "One or more images could not be deleted"
},
"hotkeys": {
"searchHotkeys": "Search Hotkeys",

View File

@@ -6,7 +6,7 @@
"settingsLabel": "Ajustes",
"img2img": "Imagen a Imagen",
"unifiedCanvas": "Lienzo Unificado",
"nodes": "Flujos de trabajo",
"nodes": "Editor del flujo de trabajo",
"upload": "Subir imagen",
"load": "Cargar",
"statusDisconnected": "Desconectado",
@@ -14,7 +14,7 @@
"discordLabel": "Discord",
"back": "Atrás",
"loading": "Cargando",
"postprocessing": "Postprocesado",
"postprocessing": "Tratamiento posterior",
"txt2img": "De texto a imagen",
"accept": "Aceptar",
"cancel": "Cancelar",
@@ -42,42 +42,7 @@
"copy": "Copiar",
"beta": "Beta",
"on": "En",
"aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:",
"installed": "Instalado",
"green": "Verde",
"editor": "Editor",
"orderBy": "Ordenar por",
"file": "Archivo",
"goTo": "Ir a",
"imageFailedToLoad": "No se puede cargar la imagen",
"saveAs": "Guardar Como",
"somethingWentWrong": "Algo salió mal",
"nextPage": "Página Siguiente",
"selected": "Seleccionado",
"tab": "Tabulador",
"positivePrompt": "Prompt Positivo",
"negativePrompt": "Prompt Negativo",
"error": "Error",
"format": "formato",
"unknown": "Desconocido",
"input": "Entrada",
"nodeEditor": "Editor de nodos",
"template": "Plantilla",
"prevPage": "Página Anterior",
"red": "Rojo",
"alpha": "Transparencia",
"outputs": "Salidas",
"editing": "Editando",
"learnMore": "Aprende más",
"enabled": "Activado",
"disabled": "Desactivado",
"folder": "Carpeta",
"updated": "Actualizado",
"created": "Creado",
"save": "Guardar",
"unknownError": "Error Desconocido",
"blue": "Azul",
"viewingDesc": "Revisar imágenes en una vista de galería grande"
"aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:"
},
"gallery": {
"galleryImageSize": "Tamaño de la imagen",
@@ -502,8 +467,7 @@
"about": "Acerca de",
"createIssue": "Crear un problema",
"resetUI": "Interfaz de usuario $t(accessibility.reset)",
"mode": "Modo",
"submitSupportTicket": "Enviar Ticket de Soporte"
"mode": "Modo"
},
"nodes": {
"zoomInNodes": "Acercar",
@@ -579,17 +543,5 @@
"layers_one": "Capa",
"layers_many": "Capas",
"layers_other": "Capas"
},
"controlnet": {
"crop": "Cortar",
"delete": "Eliminar",
"depthAnythingDescription": "Generación de mapa de profundidad usando la técnica de Depth Anything",
"duplicate": "Duplicar",
"colorMapDescription": "Genera un mapa de color desde la imagen",
"depthMidasDescription": "Crea un mapa de profundidad con Midas",
"balanced": "Equilibrado",
"beginEndStepPercent": "Inicio / Final Porcentaje de pasos",
"detectResolution": "Detectar resolución",
"beginEndStepPercentShort": "Inicio / Final %"
}
}

View File

@@ -45,7 +45,7 @@
"outputs": "Risultati",
"data": "Dati",
"somethingWentWrong": "Qualcosa è andato storto",
"copyError": "Errore $t(gallery.copy)",
"copyError": "$t(gallery.copy) Errore",
"input": "Ingresso",
"notInstalled": "Non $t(common.installed)",
"unknownError": "Errore sconosciuto",
@@ -85,11 +85,7 @@
"viewing": "Visualizza",
"viewingDesc": "Rivedi le immagini in un'ampia vista della galleria",
"editing": "Modifica",
"editingDesc": "Modifica nell'area Livelli di controllo",
"enabled": "Abilitato",
"disabled": "Disabilitato",
"comparingDesc": "Confronta due immagini",
"comparing": "Confronta"
"editingDesc": "Modifica nell'area Livelli di controllo"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -126,30 +122,14 @@
"bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
"bulkDownloadRequestFailed": "Problema durante la preparazione del download",
"bulkDownloadFailed": "Scaricamento fallito",
"alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine",
"openInViewer": "Apri nel visualizzatore",
"selectForCompare": "Seleziona per il confronto",
"selectAnImageToCompare": "Seleziona un'immagine da confrontare",
"slider": "Cursore",
"sideBySide": "Fianco a Fianco",
"compareImage": "Immagine di confronto",
"viewerImage": "Immagine visualizzata",
"hover": "Al passaggio del mouse",
"swapImages": "Scambia le immagini",
"compareOptions": "Opzioni di confronto",
"stretchToFit": "Scala per adattare",
"exitCompare": "Esci dal confronto",
"compareHelp1": "Tieni premuto <Kbd>Alt</Kbd> mentre fai clic su un'immagine della galleria o usi i tasti freccia per cambiare l'immagine di confronto.",
"compareHelp2": "Premi <Kbd>M</Kbd> per scorrere le modalità di confronto.",
"compareHelp3": "Premi <Kbd>C</Kbd> per scambiare le immagini confrontate.",
"compareHelp4": "Premi <Kbd>Z</Kbd> o <Kbd>Esc</Kbd> per uscire."
"alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine"
},
"hotkeys": {
"keyboardShortcuts": "Tasti di scelta rapida",
"appHotkeys": "Applicazione",
"generalHotkeys": "Generale",
"galleryHotkeys": "Galleria",
"unifiedCanvasHotkeys": "Tela",
"unifiedCanvasHotkeys": "Tela Unificata",
"invoke": {
"title": "Invoke",
"desc": "Genera un'immagine"
@@ -167,8 +147,8 @@
"desc": "Apre e chiude il pannello delle opzioni"
},
"pinOptions": {
"title": "Fissa le opzioni",
"desc": "Fissa il pannello delle opzioni"
"title": "Appunta le opzioni",
"desc": "Blocca il pannello delle opzioni"
},
"toggleGallery": {
"title": "Attiva/disattiva galleria",
@@ -352,14 +332,14 @@
"title": "Annulla e cancella"
},
"resetOptionsAndGallery": {
"title": "Ripristina le opzioni e la galleria",
"desc": "Reimposta i pannelli delle opzioni e della galleria"
"title": "Ripristina Opzioni e Galleria",
"desc": "Reimposta le opzioni e i pannelli della galleria"
},
"searchHotkeys": "Cerca tasti di scelta rapida",
"noHotkeysFound": "Nessun tasto di scelta rapida trovato",
"toggleOptionsAndGallery": {
"desc": "Apre e chiude le opzioni e i pannelli della galleria",
"title": "Attiva/disattiva le opzioni e la galleria"
"title": "Attiva/disattiva le Opzioni e la Galleria"
},
"clearSearch": "Cancella ricerca",
"remixImage": {
@@ -368,7 +348,7 @@
},
"toggleViewer": {
"title": "Attiva/disattiva il visualizzatore di immagini",
"desc": "Passa dal visualizzatore immagini all'area di lavoro per la scheda corrente."
"desc": "Passa dal Visualizzatore immagini all'area di lavoro per la scheda corrente."
}
},
"modelManager": {
@@ -398,7 +378,7 @@
"convertToDiffusers": "Converti in Diffusori",
"convertToDiffusersHelpText2": "Questo processo sostituirà la voce in Gestione Modelli con la versione Diffusori dello stesso modello.",
"convertToDiffusersHelpText4": "Questo è un processo una tantum. Potrebbero essere necessari circa 30-60 secondi a seconda delle specifiche del tuo computer.",
"convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB in dimensione.",
"convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB di dimensioni.",
"convertToDiffusersHelpText6": "Vuoi convertire questo modello?",
"modelConverted": "Modello convertito",
"alpha": "Alpha",
@@ -548,7 +528,7 @@
"layer": {
"initialImageNoImageSelected": "Nessuna immagine iniziale selezionata",
"t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}",
"controlAdapterNoModelSelected": "Nessun modello di adattatore di controllo selezionato",
"controlAdapterNoModelSelected": "Nessun modello di Adattatore di Controllo selezionato",
"controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile",
"controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata",
"controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata",
@@ -626,25 +606,25 @@
"canvasMerged": "Tela unita",
"sentToImageToImage": "Inviato a Generazione da immagine",
"sentToUnifiedCanvas": "Inviato alla Tela",
"parametersNotSet": "Parametri non richiamati",
"parametersNotSet": "Parametri non impostati",
"metadataLoadFailed": "Impossibile caricare i metadati",
"serverError": "Errore del Server",
"connected": "Connesso al server",
"connected": "Connesso al Server",
"canceled": "Elaborazione annullata",
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
"parameterSet": "Parametro richiamato",
"parameterNotSet": "Parametro non richiamato",
"parameterSet": "{{parameter}} impostato",
"parameterNotSet": "{{parameter}} non impostato",
"problemCopyingImage": "Impossibile copiare l'immagine",
"baseModelChangedCleared_one": "Cancellato o disabilitato {{count}} sottomodello incompatibile",
"baseModelChangedCleared_many": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
"baseModelChangedCleared_other": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
"baseModelChangedCleared_one": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modello incompatibile",
"baseModelChangedCleared_many": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili",
"baseModelChangedCleared_other": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili",
"imageSavingFailed": "Salvataggio dell'immagine non riuscito",
"canvasSentControlnetAssets": "Tela inviata a ControlNet & Risorse",
"problemCopyingCanvasDesc": "Impossibile copiare la tela",
"loadedWithWarnings": "Flusso di lavoro caricato con avvisi",
"canvasCopiedClipboard": "Tela copiata negli appunti",
"maskSavedAssets": "Maschera salvata nelle risorse",
"problemDownloadingCanvas": "Problema durante lo scarico della tela",
"problemDownloadingCanvas": "Problema durante il download della tela",
"problemMergingCanvas": "Problema nell'unione delle tele",
"imageUploaded": "Immagine caricata",
"addedToBoard": "Aggiunto alla bacheca",
@@ -678,17 +658,7 @@
"problemDownloadingImage": "Impossibile scaricare l'immagine",
"prunedQueue": "Coda ripulita",
"modelImportCanceled": "Importazione del modello annullata",
"parameters": "Parametri",
"parameterSetDesc": "{{parameter}} richiamato",
"parameterNotSetDesc": "Impossibile richiamare {{parameter}}",
"parameterNotSetDescWithMessage": "Impossibile richiamare {{parameter}}: {{message}}",
"parametersSet": "Parametri richiamati",
"errorCopied": "Errore copiato",
"outOfMemoryError": "Errore di memoria esaurita",
"baseModelChanged": "Modello base modificato",
"sessionRef": "Sessione: {{sessionId}}",
"somethingWentWrong": "Qualcosa è andato storto",
"outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova."
"parameters": "Parametri"
},
"tooltip": {
"feature": {
@@ -704,7 +674,7 @@
"layer": "Livello",
"base": "Base",
"mask": "Maschera",
"maskingOptions": "Opzioni maschera",
"maskingOptions": "Opzioni di mascheramento",
"enableMask": "Abilita maschera",
"preserveMaskedArea": "Mantieni area mascherata",
"clearMask": "Cancella maschera (Shift+C)",
@@ -775,8 +745,7 @@
"mode": "Modalità",
"resetUI": "$t(accessibility.reset) l'Interfaccia Utente",
"createIssue": "Segnala un problema",
"about": "Informazioni",
"submitSupportTicket": "Invia ticket di supporto"
"about": "Informazioni"
},
"nodes": {
"zoomOutNodes": "Rimpicciolire",
@@ -821,7 +790,7 @@
"workflowNotes": "Note",
"versionUnknown": " Versione sconosciuta",
"unableToValidateWorkflow": "Impossibile convalidare il flusso di lavoro",
"updateApp": "Aggiorna Applicazione",
"updateApp": "Aggiorna App",
"unableToLoadWorkflow": "Impossibile caricare il flusso di lavoro",
"updateNode": "Aggiorna nodo",
"version": "Versione",
@@ -913,14 +882,11 @@
"missingNode": "Nodo di invocazione mancante",
"missingInvocationTemplate": "Modello di invocazione mancante",
"missingFieldTemplate": "Modello di campo mancante",
"singleFieldType": "{{name}} (Singola)",
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
"boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
"modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
"singleFieldType": "{{name}} (Singola)"
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
"menuItemAutoAdd": "Aggiungi automaticamente a questa bacheca",
"menuItemAutoAdd": "Aggiungi automaticamente a questa Bacheca",
"cancel": "Annulla",
"addBoard": "Aggiungi Bacheca",
"bottomMessage": "L'eliminazione di questa bacheca e delle sue immagini ripristinerà tutte le funzionalità che le stanno attualmente utilizzando.",
@@ -932,7 +898,7 @@
"myBoard": "Bacheca",
"searchBoard": "Cerca bacheche ...",
"noMatching": "Nessuna bacheca corrispondente",
"selectBoard": "Seleziona una bacheca",
"selectBoard": "Seleziona una Bacheca",
"uncategorized": "Non categorizzato",
"downloadBoard": "Scarica la bacheca",
"deleteBoardOnly": "solo la Bacheca",
@@ -953,7 +919,7 @@
"control": "Controllo",
"crop": "Ritaglia",
"depthMidas": "Profondità (Midas)",
"detectResolution": "Rileva la risoluzione",
"detectResolution": "Rileva risoluzione",
"controlMode": "Modalità di controllo",
"cannyDescription": "Canny rilevamento bordi",
"depthZoe": "Profondità (Zoe)",
@@ -964,7 +930,7 @@
"showAdvanced": "Mostra opzioni Avanzate",
"bgth": "Soglia rimozione sfondo",
"importImageFromCanvas": "Importa immagine dalla Tela",
"lineartDescription": "Converte l'immagine in linea",
"lineartDescription": "Converte l'immagine in lineart",
"importMaskFromCanvas": "Importa maschera dalla Tela",
"hideAdvanced": "Nascondi opzioni avanzate",
"resetControlImage": "Reimposta immagine di controllo",
@@ -980,7 +946,7 @@
"pidiDescription": "Elaborazione immagini PIDI",
"fill": "Riempie",
"colorMapDescription": "Genera una mappa dei colori dall'immagine",
"lineartAnimeDescription": "Elaborazione linea in stile anime",
"lineartAnimeDescription": "Elaborazione lineart in stile anime",
"imageResolution": "Risoluzione dell'immagine",
"colorMap": "Colore",
"lowThreshold": "Soglia inferiore",

View File

@@ -87,11 +87,7 @@
"viewing": "Просмотр",
"editing": "Редактирование",
"viewingDesc": "Просмотр изображений в режиме большой галереи",
"editingDesc": "Редактировать на холсте слоёв управления",
"enabled": "Включено",
"disabled": "Отключено",
"comparingDesc": "Сравнение двух изображений",
"comparing": "Сравнение"
"editingDesc": "Редактировать на холсте слоёв управления"
},
"gallery": {
"galleryImageSize": "Размер изображений",
@@ -128,23 +124,7 @@
"bulkDownloadRequested": "Подготовка к скачиванию",
"bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
"bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания",
"alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения",
"openInViewer": "Открыть в просмотрщике",
"selectForCompare": "Выбрать для сравнения",
"hover": "Наведение",
"swapImages": "Поменять местами",
"stretchToFit": "Растягивание до нужного размера",
"exitCompare": "Выйти из сравнения",
"compareHelp4": "Нажмите <Kbd>Z</Kbd> или <Kbd>Esc</Kbd> для выхода.",
"compareImage": "Сравнить изображение",
"viewerImage": "Изображение просмотрщика",
"selectAnImageToCompare": "Выберите изображение для сравнения",
"slider": "Слайдер",
"sideBySide": "Бок о бок",
"compareOptions": "Варианты сравнения",
"compareHelp1": "Удерживайте <Kbd>Alt</Kbd> при нажатии на изображение в галерее или при помощи клавиш со стрелками, чтобы изменить сравниваемое изображение.",
"compareHelp2": "Нажмите <Kbd>M</Kbd>, чтобы переключиться между режимами сравнения.",
"compareHelp3": "Нажмите <Kbd>C</Kbd>, чтобы поменять местами сравниваемые изображения."
"alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения"
},
"hotkeys": {
"keyboardShortcuts": "Горячие клавиши",
@@ -548,20 +528,7 @@
"missingFieldTemplate": "Отсутствует шаблон поля",
"addingImagesTo": "Добавление изображений в",
"invoke": "Создать",
"imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается",
"layer": {
"controlAdapterImageNotProcessed": "Изображение адаптера контроля не обработано",
"ipAdapterNoModelSelected": "IP адаптер не выбран",
"controlAdapterNoModelSelected": "не выбрана модель адаптера контроля",
"controlAdapterIncompatibleBaseModel": "несовместимая базовая модель адаптера контроля",
"controlAdapterNoImageSelected": "не выбрано изображение контрольного адаптера",
"initialImageNoImageSelected": "начальное изображение не выбрано",
"rgNoRegion": "регион не выбран",
"rgNoPromptsOrIPAdapters": "нет текстовых запросов или IP-адаптеров",
"ipAdapterIncompatibleBaseModel": "несовместимая базовая модель IP-адаптера",
"t2iAdapterIncompatibleDimensions": "Адаптер T2I требует, чтобы размеры изображения были кратны {{multiple}}",
"ipAdapterNoImageSelected": "изображение IP-адаптера не выбрано"
}
"imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается"
},
"isAllowedToUpscale": {
"useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. Используйте модель x2",
@@ -639,12 +606,12 @@
"connected": "Подключено к серверу",
"canceled": "Обработка отменена",
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
"parameterNotSet": "Параметр не задан",
"parameterSet": "Параметр задан",
"parameterNotSet": "Параметр {{parameter}} не задан",
"parameterSet": "Параметр {{parameter}} задан",
"problemCopyingImage": "Не удается скопировать изображение",
"baseModelChangedCleared_one": "Очищена или отключена {{count}} несовместимая подмодель",
"baseModelChangedCleared_few": "Очищены или отключены {{count}} несовместимые подмодели",
"baseModelChangedCleared_many": "Очищены или отключены {{count}} несовместимых подмоделей",
"baseModelChangedCleared_one": "Базовая модель изменила, очистила или отключила {{count}} несовместимую подмодель",
"baseModelChangedCleared_few": "Базовая модель изменила, очистила или отключила {{count}} несовместимые подмодели",
"baseModelChangedCleared_many": "Базовая модель изменила, очистила или отключила {{count}} несовместимых подмоделей",
"imageSavingFailed": "Не удалось сохранить изображение",
"canvasSentControlnetAssets": "Холст отправлен в ControlNet и ресурсы",
"problemCopyingCanvasDesc": "Невозможно экспортировать базовый слой",
@@ -685,17 +652,7 @@
"resetInitialImage": "Сбросить начальное изображение",
"prunedQueue": "Урезанная очередь",
"modelImportCanceled": "Импорт модели отменен",
"parameters": "Параметры",
"parameterSetDesc": "Задан {{parameter}}",
"parameterNotSetDesc": "Невозможно задать {{parameter}}",
"baseModelChanged": "Базовая модель сменена",
"parameterNotSetDescWithMessage": "Не удалось задать {{parameter}}: {{message}}",
"parametersSet": "Параметры заданы",
"errorCopied": "Ошибка скопирована",
"sessionRef": "Сессия: {{sessionId}}",
"outOfMemoryError": "Ошибка нехватки памяти",
"outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
"somethingWentWrong": "Что-то пошло не так"
"parameters": "Параметры"
},
"tooltip": {
"feature": {
@@ -782,8 +739,7 @@
"loadMore": "Загрузить больше",
"resetUI": "$t(accessibility.reset) интерфейс",
"createIssue": "Сообщить о проблеме",
"about": "Об этом",
"submitSupportTicket": "Отправить тикет в службу поддержки"
"about": "Об этом"
},
"nodes": {
"zoomInNodes": "Увеличьте масштаб",
@@ -876,7 +832,7 @@
"workflowName": "Название",
"collection": "Коллекция",
"unknownErrorValidatingWorkflow": "Неизвестная ошибка при проверке рабочего процесса",
"collectionFieldType": "{{name}} (Коллекция)",
"collectionFieldType": "Коллекция {{name}}",
"workflowNotes": "Примечания",
"string": "Строка",
"unknownNodeType": "Неизвестный тип узла",
@@ -892,7 +848,7 @@
"targetNodeDoesNotExist": "Недопустимое ребро: целевой/входной узел {{node}} не существует",
"mismatchedVersion": "Недопустимый узел: узел {{node}} типа {{type}} имеет несоответствующую версию (попробовать обновить?)",
"unknownFieldType": "$t(nodes.unknownField) тип: {{type}}",
"collectionOrScalarFieldType": "{{name}} (Один или коллекция)",
"collectionOrScalarFieldType": "Коллекция | Скаляр {{name}}",
"betaDesc": "Этот вызов находится в бета-версии. Пока он не станет стабильным, в нем могут происходить изменения при обновлении приложений. Мы планируем поддерживать этот вызов в течение длительного времени.",
"nodeVersion": "Версия узла",
"loadingNodes": "Загрузка узлов...",
@@ -914,16 +870,7 @@
"noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений.",
"graph": "График",
"showEdgeLabels": "Показать метки на ребрах",
"showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы",
"cannotMixAndMatchCollectionItemTypes": "Невозможно смешивать и сопоставлять типы элементов коллекции",
"missingNode": "Отсутствует узел вызова",
"missingInvocationTemplate": "Отсутствует шаблон вызова",
"missingFieldTemplate": "Отсутствующий шаблон поля",
"singleFieldType": "{{name}} (Один)",
"noGraph": "Нет графика",
"imageAccessError": "Невозможно найти изображение {{image_name}}, сбрасываем на значение по умолчанию",
"boardAccessError": "Невозможно найти доску {{board_id}}, сбрасываем на значение по умолчанию",
"modelAccessError": "Невозможно найти модель {{key}}, сброс на модель по умолчанию"
"showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы"
},
"controlnet": {
"amult": "a_mult",
@@ -1494,16 +1441,7 @@
"clearQueueAlertDialog2": "Вы уверены, что хотите очистить очередь?",
"item": "Элемент",
"graphFailedToQueue": "Не удалось поставить график в очередь",
"openQueue": "Открыть очередь",
"prompts_one": "Запрос",
"prompts_few": "Запроса",
"prompts_many": "Запросов",
"iterations_one": "Итерация",
"iterations_few": "Итерации",
"iterations_many": "Итераций",
"generations_one": "Генерация",
"generations_few": "Генерации",
"generations_many": "Генераций"
"openQueue": "Открыть очередь"
},
"sdxl": {
"refinerStart": "Запуск доработчика",

View File

@@ -1,6 +1,6 @@
{
"common": {
"nodes": "工作流程",
"nodes": "節點",
"img2img": "圖片轉圖片",
"statusDisconnected": "已中斷連線",
"back": "返回",
@@ -11,239 +11,17 @@
"reportBugLabel": "回報錯誤",
"githubLabel": "GitHub",
"hotkeysLabel": "快捷鍵",
"languagePickerLabel": "語言",
"languagePickerLabel": "切換語言",
"unifiedCanvas": "統一畫布",
"cancel": "取消",
"txt2img": "文字轉圖片",
"controlNet": "ControlNet",
"advanced": "進階",
"folder": "資料夾",
"installed": "已安裝",
"accept": "接受",
"goTo": "前往",
"input": "輸入",
"random": "隨機",
"selected": "已選擇",
"communityLabel": "社群",
"loading": "載入中",
"delete": "刪除",
"copy": "複製",
"error": "錯誤",
"file": "檔案",
"format": "格式",
"imageFailedToLoad": "無法載入圖片"
"txt2img": "文字轉圖片"
},
"accessibility": {
"invokeProgressBar": "Invoke 進度條",
"uploadImage": "上傳圖片",
"reset": "重",
"reset": "重",
"nextImage": "下一張圖片",
"previousImage": "上一張圖片",
"menu": "選單",
"loadMore": "載入更多",
"about": "關於",
"createIssue": "建立問題",
"resetUI": "$t(accessibility.reset) 介面",
"submitSupportTicket": "提交支援工單",
"mode": "模式"
},
"boards": {
"loading": "載入中…",
"movingImagesToBoard_other": "正在移動 {{count}} 張圖片至板上:",
"move": "移動",
"uncategorized": "未分類",
"cancel": "取消"
},
"metadata": {
"workflow": "工作流程",
"steps": "步數",
"model": "模型",
"seed": "種子",
"vae": "VAE",
"seamless": "無縫",
"metadata": "元數據",
"width": "寬度",
"height": "高度"
},
"accordions": {
"control": {
"title": "控制"
},
"compositing": {
"title": "合成"
},
"advanced": {
"title": "進階",
"options": "$t(accordions.advanced.title) 選項"
}
},
"hotkeys": {
"nodesHotkeys": "節點",
"cancel": {
"title": "取消"
},
"generalHotkeys": "一般",
"keyboardShortcuts": "快捷鍵",
"appHotkeys": "應用程式"
},
"modelManager": {
"advanced": "進階",
"allModels": "全部模型",
"variant": "變體",
"config": "配置",
"model": "模型",
"selected": "已選擇",
"huggingFace": "HuggingFace",
"install": "安裝",
"metadata": "元數據",
"delete": "刪除",
"description": "描述",
"cancel": "取消",
"convert": "轉換",
"manual": "手動",
"none": "無",
"name": "名稱",
"load": "載入",
"height": "高度",
"width": "寬度",
"search": "搜尋",
"vae": "VAE",
"settings": "設定"
},
"controlnet": {
"mlsd": "M-LSD",
"canny": "Canny",
"duplicate": "重複",
"none": "無",
"pidi": "PIDI",
"h": "H",
"balanced": "平衡",
"crop": "裁切",
"processor": "處理器",
"control": "控制",
"f": "F",
"lineart": "線條藝術",
"w": "W",
"hed": "HED",
"delete": "刪除"
},
"queue": {
"queue": "佇列",
"canceled": "已取消",
"failed": "已失敗",
"completed": "已完成",
"cancel": "取消",
"session": "工作階段",
"batch": "批量",
"item": "項目",
"completedIn": "完成於",
"notReady": "無法排隊"
},
"parameters": {
"cancel": {
"cancel": "取消"
},
"height": "高度",
"type": "類型",
"symmetry": "對稱性",
"images": "圖片",
"width": "寬度",
"coherenceMode": "模式",
"seed": "種子",
"general": "一般",
"strength": "強度",
"steps": "步數",
"info": "資訊"
},
"settings": {
"beta": "Beta",
"developer": "開發者",
"general": "一般",
"models": "模型"
},
"popovers": {
"paramModel": {
"heading": "模型"
},
"compositingCoherenceMode": {
"heading": "模式"
},
"paramSteps": {
"heading": "步數"
},
"controlNetProcessor": {
"heading": "處理器"
},
"paramVAE": {
"heading": "VAE"
},
"paramHeight": {
"heading": "高度"
},
"paramSeed": {
"heading": "種子"
},
"paramWidth": {
"heading": "寬度"
},
"refinerSteps": {
"heading": "步數"
}
},
"unifiedCanvas": {
"undo": "復原",
"mask": "遮罩",
"eraser": "橡皮擦",
"antialiasing": "抗鋸齒",
"redo": "重做",
"layer": "圖層",
"accept": "接受",
"brush": "刷子",
"move": "移動",
"brushSize": "大小"
},
"nodes": {
"workflowName": "名稱",
"notes": "註釋",
"workflowVersion": "版本",
"workflowNotes": "註釋",
"executionStateError": "錯誤",
"unableToUpdateNodes_other": "無法更新 {{count}} 個節點",
"integer": "整數",
"workflow": "工作流程",
"enum": "枚舉",
"edit": "編輯",
"string": "字串",
"workflowTags": "標籤",
"node": "節點",
"boolean": "布林值",
"workflowAuthor": "作者",
"version": "版本",
"executionStateCompleted": "已完成",
"edge": "邊緣",
"versionUnknown": " 版本未知"
},
"sdxl": {
"steps": "步數",
"loading": "載入中…",
"refiner": "精煉器"
},
"gallery": {
"copy": "複製",
"download": "下載",
"loading": "載入中"
},
"ui": {
"tabs": {
"models": "模型",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"queue": "佇列"
}
},
"models": {
"loading": "載入中"
},
"workflows": {
"name": "名稱"
"menu": "選單"
}
}

View File

@@ -19,13 +19,6 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
return extendTheme({
..._theme,
direction,
shadows: {
..._theme.shadows,
selectedForCompare:
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-400)',
hoverSelectedForCompare:
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-300)',
},
});
}, [direction]);

View File

@@ -13,6 +13,7 @@ import {
isControlAdapterLayer,
} from 'features/controlLayers/store/controlLayersSlice';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { isImageOutput } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { isEqual } from 'lodash-es';
@@ -138,7 +139,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
// We still have to check the output type
assert(
invocationCompleteAction.payload.data.result.type === 'image_output',
isImageOutput(invocationCompleteAction.payload.data.result),
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
);
const { image_name } = invocationCompleteAction.payload.data.result.image;

View File

@@ -9,6 +9,7 @@ import {
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isImageOutput } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
@@ -73,7 +74,7 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
);
// We still have to check the output type
if (invocationCompleteAction.payload.data.result.type === 'image_output') {
if (isImageOutput(invocationCompleteAction.payload.data.result)) {
const { image_name } = invocationCompleteAction.payload.data.result.image;
// Wait for the ImageDTO to be received

View File

@@ -1,7 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
@@ -11,7 +11,6 @@ export const galleryImageClicked = createAction<{
shiftKey: boolean;
ctrlKey: boolean;
metaKey: boolean;
altKey: boolean;
}>('gallery/imageClicked');
/**
@@ -29,7 +28,7 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
startAppListening({
actionCreator: galleryImageClicked,
effect: async (action, { dispatch, getState }) => {
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const { imageDTO, shiftKey, ctrlKey, metaKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
@@ -42,13 +41,7 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
const imageDTOs = imagesSelectors.selectAll(listImagesData);
const selection = state.gallery.selection;
if (altKey) {
if (state.gallery.imageToCompare?.image_name === imageDTO.image_name) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(imageToCompareChanged(imageDTO));
}
} else if (shiftKey) {
if (shiftKey) {
const rangeEndImageName = imageDTO.image_name;
const lastSelectedImage = selection[selection.length - 1]?.image_name;
const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage);

View File

@@ -14,8 +14,7 @@ import {
rgLayerIPAdapterImageChanged,
} from 'features/controlLayers/store/controlLayersSlice';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -31,9 +30,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
effect: async (action, { dispatch, getState }) => {
const log = logger('dnd');
const { activeData, overData } = action.payload;
if (!isValidDrop(overData, activeData)) {
return;
}
if (activeData.payloadType === 'IMAGE_DTO') {
log.debug({ activeData, overData }, 'Image dropped');
@@ -54,7 +50,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
activeData.payload.imageDTO
) {
dispatch(imageSelected(activeData.payload.imageDTO));
dispatch(isImageViewerOpenChanged(true));
return;
}
@@ -187,18 +182,24 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
}
/**
* Image selected for compare
* TODO
* Image selection dropped on node image collection field
*/
if (
overData.actionType === 'SELECT_FOR_COMPARE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(imageToCompareChanged(imageDTO));
dispatch(isImageViewerOpenChanged(true));
return;
}
// if (
// overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
// activeData.payloadType === 'IMAGE_DTO' &&
// activeData.payload.imageDTO
// ) {
// const { fieldName, nodeId } = overData.context;
// dispatch(
// fieldValueChanged({
// nodeId,
// fieldName,
// value: [activeData.payload.imageDTO],
// })
// );
// return;
// }
/**
* Image dropped on user board

View File

@@ -11,6 +11,7 @@ import {
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { isImageOutput } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
@@ -32,7 +33,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
const { result, invocation_source_id } = data;
// This complete event has an associated image output
if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) {
if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) {
const { image_name } = data.result.image;
const { canvas, gallery } = getState();

View File

@@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { parseify } from 'common/util/serialize';
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { $needsFit } from 'features/nodes/store/reactFlowInstance';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import type { Templates } from 'features/nodes/store/types';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
@@ -65,7 +65,9 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
});
}
$needsFit.set(true);
requestAnimationFrame(() => {
$flow.get()?.fitView();
});
} catch (e) {
if (e instanceof WorkflowVersionError) {
// The workflow version was not recognized in the valid list of versions

View File

@@ -35,7 +35,6 @@ type IAIDndImageProps = FlexProps & {
draggableData?: TypesafeDraggableData;
dropLabel?: ReactNode;
isSelected?: boolean;
isSelectedForCompare?: boolean;
thumbnail?: boolean;
noContentFallback?: ReactElement;
useThumbailFallback?: boolean;
@@ -62,7 +61,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
draggableData,
dropLabel,
isSelected = false,
isSelectedForCompare = false,
thumbnail = false,
noContentFallback = defaultNoContentFallback,
uploadElement = defaultUploadElement,
@@ -167,11 +165,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
<SelectionOverlay
isSelected={isSelected}
isSelectedForCompare={isSelectedForCompare}
isHovered={withHoverOverlay ? isHovered : false}
/>
<SelectionOverlay isSelected={isSelected} isHovered={withHoverOverlay ? isHovered : false} />
</Flex>
)}
{!imageDTO && !isUploadDisabled && (

View File

@@ -36,7 +36,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
pointerEvents={active ? 'auto' : 'none'}
>
<AnimatePresence>
{isValidDrop(data, active?.data.current) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
{isValidDrop(data, active) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
</AnimatePresence>
</Box>
);

View File

@@ -3,17 +3,10 @@ import { memo, useMemo } from 'react';
type Props = {
isSelected: boolean;
isSelectedForCompare: boolean;
isHovered: boolean;
};
const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props) => {
const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
const shadow = useMemo(() => {
if (isSelectedForCompare && isHovered) {
return 'hoverSelectedForCompare';
}
if (isSelectedForCompare && !isHovered) {
return 'selectedForCompare';
}
if (isSelected && isHovered) {
return 'hoverSelected';
}
@@ -24,7 +17,7 @@ const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props
return 'hoverUnselected';
}
return undefined;
}, [isHovered, isSelected, isSelectedForCompare]);
}, [isHovered, isSelected]);
return (
<Box
className="selection-box"
@@ -34,7 +27,7 @@ const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props
bottom={0}
insetInlineStart={0}
borderRadius="base"
opacity={isSelected || isSelectedForCompare ? 1 : 0.7}
opacity={isSelected ? 1 : 0.7}
transitionProperty="common"
transitionDuration="0.1s"
pointerEvents="none"

View File

@@ -1,21 +0,0 @@
import { useCallback, useMemo, useState } from 'react';
export const useBoolean = (initialValue: boolean) => {
const [isTrue, set] = useState(initialValue);
const setTrue = useCallback(() => set(true), []);
const setFalse = useCallback(() => set(false), []);
const toggle = useCallback(() => set((v) => !v), []);
const api = useMemo(
() => ({
isTrue,
set,
setTrue,
setFalse,
toggle,
}),
[isTrue, set, setTrue, setFalse, toggle]
);
return api;
};

View File

@@ -1,7 +1,3 @@
export const stopPropagation = (e: React.MouseEvent) => {
e.stopPropagation();
};
export const preventDefault = (e: React.MouseEvent) => {
e.preventDefault();
};

View File

@@ -1,13 +1,7 @@
import { deepClone } from 'common/util/deepClone';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge, omit } from 'lodash-es';
import type {
AnyInvocation,
BaseModelType,
ControlNetModelConfig,
ImageDTO,
T2IAdapterModelConfig,
} from 'services/api/types';
import type { BaseModelType, ControlNetModelConfig, Graph, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
import { z } from 'zod';
const zId = z.string().min(1);
@@ -153,7 +147,7 @@ const zBeginEndStepPct = z
const zControlAdapterBase = z.object({
id: zId,
weight: z.number().gte(-1).lte(2),
weight: z.number().gte(0).lte(1),
image: zImageWithDims.nullable(),
processedImage: zImageWithDims.nullable(),
processorConfig: zProcessorConfig.nullable(),
@@ -189,7 +183,7 @@ export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safePar
export const zIPAdapterConfigV2 = z.object({
id: zId,
type: z.literal('ip_adapter'),
weight: z.number().gte(-1).lte(2),
weight: z.number().gte(0).lte(1),
method: zIPMethodV2,
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
@@ -222,7 +216,10 @@ type ProcessorData<T extends ProcessorTypeV2> = {
labelTKey: string;
descriptionTKey: string;
buildDefaults(baseModel?: BaseModelType): Extract<ProcessorConfig, { type: T }>;
buildNode(image: ImageWithDims, config: Extract<ProcessorConfig, { type: T }>): Extract<AnyInvocation, { type: T }>;
buildNode(
image: ImageWithDims,
config: Extract<ProcessorConfig, { type: T }>
): Extract<Graph['nodes'][string], { type: T }>;
};
const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height);

View File

@@ -54,7 +54,7 @@ const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
export const STAGE_BG_DATAURL =
const STAGE_BG_DATAURL =
'';
const mapId = (object: { id: string }) => object.id;

View File

@@ -18,7 +18,7 @@ type BaseDropData = {
id: string;
};
export type CurrentImageDropData = BaseDropData & {
type CurrentImageDropData = BaseDropData & {
actionType: 'SET_CURRENT_IMAGE';
};
@@ -79,14 +79,6 @@ export type RemoveFromBoardDropData = BaseDropData & {
actionType: 'REMOVE_FROM_BOARD';
};
export type SelectForCompareDropData = BaseDropData & {
actionType: 'SELECT_FOR_COMPARE';
context: {
firstImageName?: string | null;
secondImageName?: string | null;
};
};
export type TypesafeDroppableData =
| CurrentImageDropData
| ControlAdapterDropData
@@ -97,8 +89,7 @@ export type TypesafeDroppableData =
| CALayerImageDropData
| IPALayerImageDropData
| RGLayerIPAdapterImageDropData
| IILayerImageDropData
| SelectForCompareDropData;
| IILayerImageDropData;
type BaseDragData = {
id: string;
@@ -143,7 +134,7 @@ export type UseDraggableTypesafeReturnValue = Omit<ReturnType<typeof useOriginal
over: TypesafeOver | null;
};
interface TypesafeActive extends Omit<Active, 'data'> {
export interface TypesafeActive extends Omit<Active, 'data'> {
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
}

View File

@@ -1,14 +1,14 @@
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import type { TypesafeActive, TypesafeDroppableData } from 'features/dnd/types';
export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?: TypesafeDraggableData | null) => {
if (!overData || !activeData) {
export const isValidDrop = (overData: TypesafeDroppableData | undefined, active: TypesafeActive | null) => {
if (!overData || !active?.data.current) {
return false;
}
const { actionType } = overData;
const { payloadType } = activeData;
const { payloadType } = active.data.current;
if (overData.id === activeData.id) {
if (overData.id === active.data.current.id) {
return false;
}
@@ -29,8 +29,6 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SELECT_FOR_COMPARE':
return payloadType === 'IMAGE_DTO';
case 'ADD_TO_BOARD': {
// If the board is the same, don't allow the drop
@@ -42,7 +40,7 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = activeData.payload;
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
const destinationBoard = overData.context.boardId;
@@ -51,7 +49,7 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
if (payloadType === 'GALLERY_SELECTION') {
// Assume all images are on the same board - this is true for the moment
const currentBoard = activeData.payload.boardId;
const currentBoard = active.data.current.payload.boardId;
const destinationBoard = overData.context.boardId;
return currentBoard !== destinationBoard;
}
@@ -69,14 +67,14 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = activeData.payload;
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
return currentBoard !== 'none';
}
if (payloadType === 'GALLERY_SELECTION') {
const currentBoard = activeData.payload.boardId;
const currentBoard = active.data.current.payload.boardId;
return currentBoard !== 'none';
}

View File

@@ -162,7 +162,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
</Flex>
)}
{isSelectedForAutoAdd && <AutoAddIcon />}
<SelectionOverlay isSelected={isSelected} isSelectedForCompare={false} isHovered={isHovered} />
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
<Flex
position="absolute"
bottom={0}

View File

@@ -117,7 +117,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
>
{boardName}
</Flex>
<SelectionOverlay isSelected={isSelected} isSelectedForCompare={false} isHovered={isHovered} />
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="md">{t('unifiedCanvas.move')}</Text>} />
</Flex>
</Tooltip>

View File

@@ -10,7 +10,6 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions';
import { imageToCompareChanged } from 'features/gallery/store/gallerySlice';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@@ -28,7 +27,6 @@ import {
PiDownloadSimpleBold,
PiFlowArrowBold,
PiFoldersBold,
PiImagesBold,
PiPlantBold,
PiQuotesBold,
PiShareFatBold,
@@ -46,7 +44,6 @@ type SingleSelectionMenuItemsProps = {
const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { imageDTO } = props;
const optimalDimension = useAppSelector(selectOptimalDimension);
const maySelectForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name !== imageDTO.image_name);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isCanvasEnabled = useFeatureStatus('canvas');
@@ -120,10 +117,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
downloadImage(imageDTO.image_url, imageDTO.image_name);
}, [downloadImage, imageDTO.image_name, imageDTO.image_url]);
const handleSelectImageForCompare = useCallback(() => {
dispatch(imageToCompareChanged(imageDTO));
}, [dispatch, imageDTO]);
return (
<>
<MenuItem as="a" href={imageDTO.image_url} target="_blank" icon={<PiShareFatBold />}>
@@ -137,9 +130,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
<MenuItem icon={<PiDownloadSimpleBold />} onClickCapture={handleDownloadImage}>
{t('parameters.downloadImage')}
</MenuItem>
<MenuItem icon={<PiImagesBold />} isDisabled={!maySelectForCompare} onClick={handleSelectImageForCompare}>
{t('gallery.selectForCompare')}
</MenuItem>
<MenuDivider />
<MenuItem
icon={getAndLoadEmbeddedWorkflowResult.isLoading ? <SpinnerIcon /> : <PiFlowArrowBold />}

View File

@@ -11,7 +11,7 @@ import type { GallerySelectionDraggableData, ImageDraggableData, TypesafeDraggab
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
import { useScrollIntoView } from 'features/gallery/hooks/useScrollIntoView';
import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import type { MouseEvent } from 'react';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
@@ -46,7 +46,6 @@ const GalleryImage = (props: HoverableImageProps) => {
const { t } = useTranslation();
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageName);
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
const customStarUi = useStore($customStarUI);
@@ -106,7 +105,6 @@ const GalleryImage = (props: HoverableImageProps) => {
const onDoubleClick = useCallback(() => {
dispatch(isImageViewerOpenChanged(true));
dispatch(imageToCompareChanged(null));
}, [dispatch]);
const handleMouseOut = useCallback(() => {
@@ -154,7 +152,6 @@ const GalleryImage = (props: HoverableImageProps) => {
imageDTO={imageDTO}
draggableData={draggableData}
isSelected={isSelected}
isSelectedForCompare={isSelectedForCompare}
minSize={0}
imageSx={imageSx}
isDropDisabled={true}

View File

@@ -28,9 +28,7 @@ const ImageMetadataGraphTabContent = ({ image }: Props) => {
return <IAINoContentFallback label={t('nodes.noGraph')} />;
}
return (
<DataViewer fileName={`${image.image_name.replace('.png', '')}_graph`} data={graph} label={t('nodes.graph')} />
);
return <DataViewer data={graph} label={t('nodes.graph')} />;
};
export default memo(ImageMetadataGraphTabContent);

View File

@@ -68,22 +68,14 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
</TabPanel>
<TabPanel>
{metadata ? (
<DataViewer
fileName={`${image.image_name.replace('.png', '')}_metadata`}
data={metadata}
label={t('metadata.metadata')}
/>
<DataViewer data={metadata} label={t('metadata.metadata')} />
) : (
<IAINoContentFallback label={t('metadata.noMetaData')} />
)}
</TabPanel>
<TabPanel>
{image ? (
<DataViewer
fileName={`${image.image_name.replace('.png', '')}_details`}
data={image}
label={t('metadata.imageDetails')}
/>
<DataViewer data={image} label={t('metadata.imageDetails')} />
) : (
<IAINoContentFallback label={t('metadata.noImageDetails')} />
)}

View File

@@ -28,13 +28,7 @@ const ImageMetadataWorkflowTabContent = ({ image }: Props) => {
return <IAINoContentFallback label={t('nodes.noWorkflow')} />;
}
return (
<DataViewer
fileName={`${image.image_name.replace('.png', '')}_workflow`}
data={workflow}
label={t('metadata.workflow')}
/>
);
return <DataViewer data={workflow} label={t('metadata.workflow')} />;
};
export default memo(ImageMetadataWorkflowTabContent);

View File

@@ -1,140 +0,0 @@
import {
Button,
ButtonGroup,
Flex,
Icon,
IconButton,
Kbd,
ListItem,
Tooltip,
UnorderedList,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
comparedImagesSwapped,
comparisonFitChanged,
comparisonModeChanged,
comparisonModeCycled,
imageToCompareChanged,
} from 'features/gallery/store/gallerySlice';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { Trans, useTranslation } from 'react-i18next';
import { PiArrowsOutBold, PiQuestion, PiSwapBold, PiXBold } from 'react-icons/pi';
export const CompareToolbar = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode);
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
const setComparisonModeSlider = useCallback(() => {
dispatch(comparisonModeChanged('slider'));
}, [dispatch]);
const setComparisonModeSideBySide = useCallback(() => {
dispatch(comparisonModeChanged('side-by-side'));
}, [dispatch]);
const setComparisonModeHover = useCallback(() => {
dispatch(comparisonModeChanged('hover'));
}, [dispatch]);
const swapImages = useCallback(() => {
dispatch(comparedImagesSwapped());
}, [dispatch]);
useHotkeys('c', swapImages, [swapImages]);
const toggleComparisonFit = useCallback(() => {
dispatch(comparisonFitChanged(comparisonFit === 'contain' ? 'fill' : 'contain'));
}, [dispatch, comparisonFit]);
const exitCompare = useCallback(() => {
dispatch(imageToCompareChanged(null));
}, [dispatch]);
useHotkeys('esc', exitCompare, [exitCompare]);
const nextMode = useCallback(() => {
dispatch(comparisonModeCycled());
}, [dispatch]);
useHotkeys('m', nextMode, [nextMode]);
return (
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<IconButton
icon={<PiSwapBold />}
aria-label={`${t('gallery.swapImages')} (C)`}
tooltip={`${t('gallery.swapImages')} (C)`}
onClick={swapImages}
/>
{comparisonMode !== 'side-by-side' && (
<IconButton
aria-label={t('gallery.stretchToFit')}
tooltip={t('gallery.stretchToFit')}
onClick={toggleComparisonFit}
colorScheme={comparisonFit === 'fill' ? 'invokeBlue' : 'base'}
variant="outline"
icon={<PiArrowsOutBold />}
/>
)}
</Flex>
</Flex>
<Flex flex={1} gap={4} justifyContent="center">
<ButtonGroup variant="outline">
<Button
flexShrink={0}
onClick={setComparisonModeSlider}
colorScheme={comparisonMode === 'slider' ? 'invokeBlue' : 'base'}
>
{t('gallery.slider')}
</Button>
<Button
flexShrink={0}
onClick={setComparisonModeSideBySide}
colorScheme={comparisonMode === 'side-by-side' ? 'invokeBlue' : 'base'}
>
{t('gallery.sideBySide')}
</Button>
<Button
flexShrink={0}
onClick={setComparisonModeHover}
colorScheme={comparisonMode === 'hover' ? 'invokeBlue' : 'base'}
>
{t('gallery.hover')}
</Button>
</ButtonGroup>
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto" alignItems="center">
<Tooltip label={<CompareHelp />}>
<Flex alignItems="center">
<Icon boxSize={8} color="base.500" as={PiQuestion} lineHeight={0} />
</Flex>
</Tooltip>
<IconButton
icon={<PiXBold />}
aria-label={`${t('gallery.exitCompare')} (Esc)`}
tooltip={`${t('gallery.exitCompare')} (Esc)`}
onClick={exitCompare}
/>
</Flex>
</Flex>
</Flex>
);
});
CompareToolbar.displayName = 'CompareToolbar';
const CompareHelp = () => {
return (
<UnorderedList>
<ListItem>
<Trans i18nKey="gallery.compareHelp1" components={{ Kbd: <Kbd /> }}></Trans>
</ListItem>
<ListItem>
<Trans i18nKey="gallery.compareHelp2" components={{ Kbd: <Kbd /> }}></Trans>
</ListItem>
<ListItem>
<Trans i18nKey="gallery.compareHelp3" components={{ Kbd: <Kbd /> }}></Trans>
</ListItem>
<ListItem>
<Trans i18nKey="gallery.compareHelp4" components={{ Kbd: <Kbd /> }}></Trans>
</ListItem>
</UnorderedList>
);
};

View File

@@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import type { TypesafeDraggableData } from 'features/dnd/types';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
@@ -22,7 +22,21 @@ const selectLastSelectedImageName = createSelector(
(lastSelectedImage) => lastSelectedImage?.image_name
);
const CurrentImagePreview = () => {
type Props = {
isDragDisabled?: boolean;
isDropDisabled?: boolean;
withNextPrevButtons?: boolean;
withMetadata?: boolean;
alwaysShowProgress?: boolean;
};
const CurrentImagePreview = ({
isDragDisabled = false,
isDropDisabled = false,
withNextPrevButtons = true,
withMetadata = true,
alwaysShowProgress = false,
}: Props) => {
const { t } = useTranslation();
const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
const imageName = useAppSelector(selectLastSelectedImageName);
@@ -41,6 +55,14 @@ const CurrentImagePreview = () => {
}
}, [imageDTO]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id: 'current-image',
actionType: 'SET_CURRENT_IMAGE',
}),
[]
);
// Show and hide the next/prev buttons on mouse move
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = useState<boolean>(false);
const timeoutId = useRef(0);
@@ -64,27 +86,30 @@ const CurrentImagePreview = () => {
justifyContent="center"
position="relative"
>
{hasDenoiseProgress && shouldShowProgressInViewer ? (
{hasDenoiseProgress && (shouldShowProgressInViewer || alwaysShowProgress) ? (
<ProgressImage />
) : (
<IAIDndImage
imageDTO={imageDTO}
droppableData={droppableData}
draggableData={draggableData}
isDropDisabled={true}
isDragDisabled={isDragDisabled}
isDropDisabled={isDropDisabled}
isUploadDisabled={true}
fitContainer
useThumbailFallback
dropLabel={t('gallery.setCurrentImage')}
noContentFallback={<IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />}
dataTestId="image-preview"
/>
)}
{shouldShowImageDetails && imageDTO && (
{shouldShowImageDetails && imageDTO && withMetadata && (
<Box position="absolute" opacity={0.8} top={0} width="full" height="full" borderRadius="base">
<ImageMetadataViewer image={imageDTO} />
</Box>
)}
<AnimatePresence>
{shouldShowNextPrevButtons && imageDTO && (
{withNextPrevButtons && shouldShowNextPrevButtons && imageDTO && (
<Box
as={motion.div}
key="nextPrevButtons"

View File

@@ -1,41 +0,0 @@
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { selectComparisonImages } from 'features/gallery/components/ImageViewer/common';
import { ImageComparisonHover } from 'features/gallery/components/ImageViewer/ImageComparisonHover';
import { ImageComparisonSideBySide } from 'features/gallery/components/ImageViewer/ImageComparisonSideBySide';
import { ImageComparisonSlider } from 'features/gallery/components/ImageViewer/ImageComparisonSlider';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImagesBold } from 'react-icons/pi';
type Props = {
containerDims: Dimensions;
};
export const ImageComparison = memo(({ containerDims }: Props) => {
const { t } = useTranslation();
const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode);
const { firstImage, secondImage } = useAppSelector(selectComparisonImages);
if (!firstImage || !secondImage) {
// Should rarely/never happen - we don't render this component unless we have images to compare
return <IAINoContentFallback label={t('gallery.selectAnImageToCompare')} icon={PiImagesBold} />;
}
if (comparisonMode === 'slider') {
return <ImageComparisonSlider containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />;
}
if (comparisonMode === 'side-by-side') {
return (
<ImageComparisonSideBySide containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />
);
}
if (comparisonMode === 'hover') {
return <ImageComparisonHover containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />;
}
});
ImageComparison.displayName = 'ImageComparison';

View File

@@ -1,47 +0,0 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDroppable from 'common/components/IAIDroppable';
import type { CurrentImageDropData, SelectForCompareDropData } from 'features/dnd/types';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { selectComparisonImages } from './common';
const setCurrentImageDropData: CurrentImageDropData = {
id: 'current-image',
actionType: 'SET_CURRENT_IMAGE',
};
export const ImageComparisonDroppable = memo(() => {
const { t } = useTranslation();
const imageViewer = useImageViewer();
const { firstImage, secondImage } = useAppSelector(selectComparisonImages);
const selectForCompareDropData = useMemo<SelectForCompareDropData>(
() => ({
id: 'image-comparison',
actionType: 'SELECT_FOR_COMPARE',
context: {
firstImageName: firstImage?.image_name,
secondImageName: secondImage?.image_name,
},
}),
[firstImage?.image_name, secondImage?.image_name]
);
if (!imageViewer.isOpen) {
return (
<Flex position="absolute" top={0} right={0} bottom={0} left={0} gap={2} pointerEvents="none">
<IAIDroppable data={setCurrentImageDropData} dropLabel={t('gallery.openInViewer')} />
</Flex>
);
}
return (
<Flex position="absolute" top={0} right={0} bottom={0} left={0} gap={2} pointerEvents="none">
<IAIDroppable data={selectForCompareDropData} dropLabel={t('gallery.selectForCompare')} />
</Flex>
);
});
ImageComparisonDroppable.displayName = 'ImageComparisonDroppable';

View File

@@ -1,117 +0,0 @@
import { Box, Flex, Image } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useBoolean } from 'common/hooks/useBoolean';
import { preventDefault } from 'common/util/stopPropagation';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import { memo, useMemo, useRef } from 'react';
import type { ComparisonProps } from './common';
import { fitDimsToContainer, getSecondImageDims } from './common';
export const ImageComparisonHover = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
const imageContainerRef = useRef<HTMLDivElement>(null);
const mouseOver = useBoolean(false);
const fittedDims = useMemo<Dimensions>(
() => fitDimsToContainer(containerDims, firstImage),
[containerDims, firstImage]
);
const compareImageDims = useMemo<Dimensions>(
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
[comparisonFit, fittedDims, firstImage, secondImage]
);
return (
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
<Flex
id="image-comparison-wrapper"
w="full"
h="full"
maxW="full"
maxH="full"
position="absolute"
alignItems="center"
justifyContent="center"
>
<Box
ref={imageContainerRef}
position="relative"
id="image-comparison-hover-image-container"
w={fittedDims.width}
h={fittedDims.height}
maxW="full"
maxH="full"
userSelect="none"
overflow="hidden"
borderRadius="base"
>
<Image
id="image-comparison-hover-first-image"
src={firstImage.image_url}
fallbackSrc={firstImage.thumbnail_url}
w={fittedDims.width}
h={fittedDims.height}
maxW="full"
maxH="full"
objectFit="cover"
objectPosition="top left"
/>
<ImageComparisonLabel type="first" opacity={mouseOver.isTrue ? 0 : 1} />
<Box
id="image-comparison-hover-second-image-container"
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
overflow="hidden"
opacity={mouseOver.isTrue ? 1 : 0}
transitionDuration="0.2s"
transitionProperty="common"
>
<Box
id="image-comparison-hover-bg"
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
backgroundImage={STAGE_BG_DATAURL}
backgroundRepeat="repeat"
opacity={0.2}
/>
<Image
position="relative"
id="image-comparison-hover-second-image"
src={secondImage.image_url}
fallbackSrc={secondImage.thumbnail_url}
w={compareImageDims.width}
h={compareImageDims.height}
maxW={fittedDims.width}
maxH={fittedDims.height}
objectFit={comparisonFit}
objectPosition="top left"
/>
<ImageComparisonLabel type="second" opacity={mouseOver.isTrue ? 1 : 0} />
</Box>
<Box
id="image-comparison-hover-interaction-overlay"
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
onMouseOver={mouseOver.setTrue}
onMouseOut={mouseOver.setFalse}
onContextMenu={preventDefault}
userSelect="none"
/>
</Box>
</Flex>
</Flex>
);
});
ImageComparisonHover.displayName = 'ImageComparisonHover';

View File

@@ -1,33 +0,0 @@
import type { TextProps } from '@invoke-ai/ui-library';
import { Text } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { DROP_SHADOW } from './common';
type Props = TextProps & {
type: 'first' | 'second';
};
export const ImageComparisonLabel = memo(({ type, ...rest }: Props) => {
const { t } = useTranslation();
return (
<Text
position="absolute"
bottom={4}
insetInlineEnd={type === 'first' ? undefined : 4}
insetInlineStart={type === 'first' ? 4 : undefined}
textOverflow="clip"
whiteSpace="nowrap"
filter={DROP_SHADOW}
color="base.50"
transitionDuration="0.2s"
transitionProperty="common"
{...rest}
>
{type === 'first' ? t('gallery.viewerImage') : t('gallery.compareImage')}
</Text>
);
});
ImageComparisonLabel.displayName = 'ImageComparisonLabel';

View File

@@ -1,70 +0,0 @@
import { Flex, Image } from '@invoke-ai/ui-library';
import type { ComparisonProps } from 'features/gallery/components/ImageViewer/common';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
import { memo, useCallback, useRef } from 'react';
import type { ImperativePanelGroupHandle } from 'react-resizable-panels';
import { Panel, PanelGroup } from 'react-resizable-panels';
export const ImageComparisonSideBySide = memo(({ firstImage, secondImage }: ComparisonProps) => {
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
const onDoubleClickHandle = useCallback(() => {
if (!panelGroupRef.current) {
return;
}
panelGroupRef.current.setLayout([50, 50]);
}, []);
return (
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
<Flex w="full" h="full" maxW="full" maxH="full" position="absolute" alignItems="center" justifyContent="center">
<PanelGroup ref={panelGroupRef} direction="horizontal" id="image-comparison-side-by-side">
<Panel minSize={20}>
<Flex position="relative" w="full" h="full" alignItems="center" justifyContent="center">
<Flex position="absolute" maxW="full" maxH="full" aspectRatio={firstImage.width / firstImage.height}>
<Image
id="image-comparison-side-by-side-first-image"
w={firstImage.width}
h={firstImage.height}
maxW="full"
maxH="full"
src={firstImage.image_url}
fallbackSrc={firstImage.thumbnail_url}
objectFit="contain"
borderRadius="base"
/>
<ImageComparisonLabel type="first" />
</Flex>
</Flex>
</Panel>
<ResizeHandle
id="image-comparison-side-by-side-handle"
onDoubleClick={onDoubleClickHandle}
orientation="vertical"
/>
<Panel minSize={20}>
<Flex position="relative" w="full" h="full" alignItems="center" justifyContent="center">
<Flex position="absolute" maxW="full" maxH="full" aspectRatio={secondImage.width / secondImage.height}>
<Image
id="image-comparison-side-by-side-first-image"
w={secondImage.width}
h={secondImage.height}
maxW="full"
maxH="full"
src={secondImage.image_url}
fallbackSrc={secondImage.thumbnail_url}
objectFit="contain"
borderRadius="base"
/>
<ImageComparisonLabel type="second" />
</Flex>
</Flex>
</Panel>
</PanelGroup>
</Flex>
</Flex>
);
});
ImageComparisonSideBySide.displayName = 'ImageComparisonSideBySide';

View File

@@ -1,215 +0,0 @@
import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { preventDefault } from 'common/util/stopPropagation';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
import type { ComparisonProps } from './common';
import { DROP_SHADOW, fitDimsToContainer, getSecondImageDims } from './common';
const INITIAL_POS = '50%';
const HANDLE_WIDTH = 2;
const HANDLE_WIDTH_PX = `${HANDLE_WIDTH}px`;
const HANDLE_HITBOX = 20;
const HANDLE_HITBOX_PX = `${HANDLE_HITBOX}px`;
const HANDLE_INNER_LEFT_PX = `${HANDLE_HITBOX / 2 - HANDLE_WIDTH / 2}px`;
const HANDLE_LEFT_INITIAL_PX = `calc(${INITIAL_POS} - ${HANDLE_HITBOX / 2}px)`;
export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
// How far the handle is from the left - this will be a CSS calculation that takes into account the handle width
const [left, setLeft] = useState(HANDLE_LEFT_INITIAL_PX);
// How wide the first image is
const [width, setWidth] = useState(INITIAL_POS);
const handleRef = useRef<HTMLDivElement>(null);
// To manage aspect ratios, we need to know the size of the container
const imageContainerRef = useRef<HTMLDivElement>(null);
// To keep things smooth, we use RAF to update the handle position & gate it to 60fps
const rafRef = useRef<number | null>(null);
const lastMoveTimeRef = useRef<number>(0);
const fittedDims = useMemo<Dimensions>(
() => fitDimsToContainer(containerDims, firstImage),
[containerDims, firstImage]
);
const compareImageDims = useMemo<Dimensions>(
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
[comparisonFit, fittedDims, firstImage, secondImage]
);
const updateHandlePos = useCallback((clientX: number) => {
if (!handleRef.current || !imageContainerRef.current) {
return;
}
lastMoveTimeRef.current = performance.now();
const { x, width } = imageContainerRef.current.getBoundingClientRect();
const rawHandlePos = ((clientX - x) * 100) / width;
const handleWidthPct = (HANDLE_WIDTH * 100) / width;
const newHandlePos = Math.min(100 - handleWidthPct, Math.max(0, rawHandlePos));
setWidth(`${newHandlePos}%`);
setLeft(`calc(${newHandlePos}% - ${HANDLE_HITBOX / 2}px)`);
}, []);
const onMouseMove = useCallback(
(e: MouseEvent) => {
if (rafRef.current === null && performance.now() > lastMoveTimeRef.current + 1000 / 60) {
rafRef.current = window.requestAnimationFrame(() => {
updateHandlePos(e.clientX);
rafRef.current = null;
});
}
},
[updateHandlePos]
);
const onMouseUp = useCallback(() => {
window.removeEventListener('mousemove', onMouseMove);
}, [onMouseMove]);
const onMouseDown = useCallback(
(e: React.MouseEvent<HTMLDivElement>) => {
// Update the handle position immediately on click
updateHandlePos(e.clientX);
window.addEventListener('mouseup', onMouseUp, { once: true });
window.addEventListener('mousemove', onMouseMove);
},
[onMouseMove, onMouseUp, updateHandlePos]
);
useEffect(
() => () => {
if (rafRef.current !== null) {
cancelAnimationFrame(rafRef.current);
}
},
[]
);
return (
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
<Flex
id="image-comparison-wrapper"
w="full"
h="full"
maxW="full"
maxH="full"
position="absolute"
alignItems="center"
justifyContent="center"
>
<Box
ref={imageContainerRef}
position="relative"
id="image-comparison-image-container"
w={fittedDims.width}
h={fittedDims.height}
maxW="full"
maxH="full"
userSelect="none"
overflow="hidden"
borderRadius="base"
>
<Box
id="image-comparison-bg"
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
backgroundImage={STAGE_BG_DATAURL}
backgroundRepeat="repeat"
opacity={0.2}
/>
<Image
position="relative"
id="image-comparison-second-image"
src={secondImage.image_url}
fallbackSrc={secondImage.thumbnail_url}
w={compareImageDims.width}
h={compareImageDims.height}
maxW={fittedDims.width}
maxH={fittedDims.height}
objectFit={comparisonFit}
objectPosition="top left"
/>
<ImageComparisonLabel type="second" />
<Box
id="image-comparison-first-image-container"
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
w={width}
overflow="hidden"
>
<Image
id="image-comparison-first-image"
src={firstImage.image_url}
fallbackSrc={firstImage.thumbnail_url}
w={fittedDims.width}
h={fittedDims.height}
objectFit="cover"
objectPosition="top left"
/>
<ImageComparisonLabel type="first" />
</Box>
<Flex
id="image-comparison-handle"
ref={handleRef}
position="absolute"
top={0}
bottom={0}
left={left}
w={HANDLE_HITBOX_PX}
cursor="ew-resize"
filter={DROP_SHADOW}
opacity={0.8}
color="base.50"
>
<Box
id="image-comparison-handle-divider"
w={HANDLE_WIDTH_PX}
h="full"
bg="currentColor"
shadow="dark-lg"
position="absolute"
top={0}
left={HANDLE_INNER_LEFT_PX}
/>
<Flex
id="image-comparison-handle-icons"
gap={4}
position="absolute"
left="50%"
top="50%"
transform="translate(-50%, 0)"
filter={DROP_SHADOW}
>
<Icon as={PiCaretLeftBold} />
<Icon as={PiCaretRightBold} />
</Flex>
</Flex>
<Box
id="image-comparison-interaction-overlay"
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
onMouseDown={onMouseDown}
onContextMenu={preventDefault}
userSelect="none"
cursor="ew-resize"
/>
</Box>
</Flex>
</Flex>
);
});
ImageComparisonSlider.displayName = 'ImageComparisonSlider';

View File

@@ -1,16 +1,36 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { CompareToolbar } from 'features/gallery/components/ImageViewer/CompareToolbar';
import CurrentImagePreview from 'features/gallery/components/ImageViewer/CurrentImagePreview';
import { ImageComparison } from 'features/gallery/components/ImageViewer/ImageComparison';
import { ViewerToolbar } from 'features/gallery/components/ImageViewer/ViewerToolbar';
import { memo } from 'react';
import { useMeasure } from 'react-use';
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useImageViewer } from './useImageViewer';
import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview';
import { ViewerToggleMenu } from './ViewerToggleMenu';
const VIEWER_ENABLED_TABS: InvokeTabName[] = ['canvas', 'generation', 'workflows'];
export const ImageViewer = memo(() => {
const imageViewer = useImageViewer();
const [containerRef, containerDims] = useMeasure<HTMLDivElement>();
const { isOpen, onToggle, onClose } = useImageViewer();
const activeTabName = useAppSelector(activeTabNameSelector);
const isViewerEnabled = useMemo(() => VIEWER_ENABLED_TABS.includes(activeTabName), [activeTabName]);
const shouldShowViewer = useMemo(() => {
if (!isViewerEnabled) {
return false;
}
return isOpen;
}, [isOpen, isViewerEnabled]);
useHotkeys('z', onToggle, { enabled: isViewerEnabled }, [isViewerEnabled, onToggle]);
useHotkeys('esc', onClose, { enabled: isViewerEnabled }, [isViewerEnabled, onClose]);
if (!shouldShowViewer) {
return null;
}
return (
<Flex
@@ -26,13 +46,25 @@ export const ImageViewer = memo(() => {
rowGap={4}
alignItems="center"
justifyContent="center"
zIndex={10} // reactflow puts its minimap at 5, so we need to be above that
>
{imageViewer.isComparing && <CompareToolbar />}
{!imageViewer.isComparing && <ViewerToolbar />}
<Box ref={containerRef} w="full" h="full">
{!imageViewer.isComparing && <CurrentImagePreview />}
{imageViewer.isComparing && <ImageComparison containerDims={containerDims} />}
</Box>
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<ToggleProgressButton />
<ToggleMetadataViewerButton />
</Flex>
</Flex>
<Flex flex={1} gap={2} justifyContent="center">
<CurrentImageButtons />
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto">
<ViewerToggleMenu />
</Flex>
</Flex>
</Flex>
<CurrentImagePreview />
</Flex>
);
});

View File

@@ -0,0 +1,45 @@
import { Flex } from '@invoke-ai/ui-library';
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { memo } from 'react';
import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview';
export const ImageViewerWorkflows = memo(() => {
return (
<Flex
layerStyle="first"
borderRadius="base"
position="absolute"
flexDirection="column"
top={0}
right={0}
bottom={0}
left={0}
p={2}
rowGap={4}
alignItems="center"
justifyContent="center"
zIndex={10} // reactflow puts its minimap at 5, so we need to be above that
>
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<ToggleProgressButton />
<ToggleMetadataViewerButton />
</Flex>
</Flex>
<Flex flex={1} gap={2} justifyContent="center">
<CurrentImageButtons />
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto" />
</Flex>
</Flex>
<CurrentImagePreview />
</Flex>
);
});
ImageViewerWorkflows.displayName = 'ImageViewerWorkflows';

View File

@@ -9,35 +9,33 @@ import {
PopoverTrigger,
Text,
} from '@invoke-ai/ui-library';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold, PiCheckBold, PiEyeBold, PiPencilBold } from 'react-icons/pi';
import { useImageViewer } from './useImageViewer';
export const ViewerToggleMenu = () => {
const { t } = useTranslation();
const imageViewer = useImageViewer();
useHotkeys('z', imageViewer.onToggle, [imageViewer]);
useHotkeys('esc', imageViewer.onClose, [imageViewer]);
const { isOpen, onClose, onOpen } = useImageViewer();
return (
<Popover isLazy>
<PopoverTrigger>
<Button variant="outline" data-testid="toggle-viewer-menu-button" pointerEvents="auto">
<Button variant="outline" data-testid="toggle-viewer-menu-button">
<Flex gap={3} w="full" alignItems="center">
{imageViewer.isOpen ? <Icon as={PiEyeBold} /> : <Icon as={PiPencilBold} />}
<Text fontSize="md">{imageViewer.isOpen ? t('common.viewing') : t('common.editing')}</Text>
{isOpen ? <Icon as={PiEyeBold} /> : <Icon as={PiPencilBold} />}
<Text fontSize="md">{isOpen ? t('common.viewing') : t('common.editing')}</Text>
<Icon as={PiCaretDownBold} />
</Flex>
</Button>
</PopoverTrigger>
<PopoverContent p={2} pointerEvents="auto">
<PopoverContent p={2}>
<PopoverArrow />
<PopoverBody>
<Flex flexDir="column">
<Button onClick={imageViewer.onOpen} variant="ghost" h="auto" w="auto" p={2}>
<Button onClick={onOpen} variant="ghost" h="auto" w="auto" p={2}>
<Flex gap={2} w="full">
<Icon as={PiCheckBold} visibility={imageViewer.isOpen ? 'visible' : 'hidden'} />
<Icon as={PiCheckBold} visibility={isOpen ? 'visible' : 'hidden'} />
<Flex flexDir="column" gap={2} alignItems="flex-start">
<Text fontWeight="semibold" color="base.100">
{t('common.viewing')}
@@ -48,9 +46,9 @@ export const ViewerToggleMenu = () => {
</Flex>
</Flex>
</Button>
<Button onClick={imageViewer.onClose} variant="ghost" h="auto" w="auto" p={2}>
<Button onClick={onClose} variant="ghost" h="auto" w="auto" p={2}>
<Flex gap={2} w="full">
<Icon as={PiCheckBold} visibility={imageViewer.isOpen ? 'hidden' : 'visible'} />
<Icon as={PiCheckBold} visibility={isOpen ? 'hidden' : 'visible'} />
<Flex flexDir="column" gap={2} alignItems="flex-start">
<Text fontWeight="semibold" color="base.100">
{t('common.editing')}

View File

@@ -1,33 +0,0 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react';
import CurrentImageButtons from './CurrentImageButtons';
import { ViewerToggleMenu } from './ViewerToggleMenu';
export const ViewerToolbar = memo(() => {
const tab = useAppSelector(activeTabNameSelector);
return (
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<ToggleProgressButton />
<ToggleMetadataViewerButton />
</Flex>
</Flex>
<Flex flex={1} gap={2} justifyContent="center">
<CurrentImageButtons />
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto">
{tab !== 'workflows' && <ViewerToggleMenu />}
</Flex>
</Flex>
</Flex>
);
});
ViewerToolbar.displayName = 'ViewerToolbar';

View File

@@ -1,64 +0,0 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import type { ComparisonFit } from 'features/gallery/store/types';
import type { ImageDTO } from 'services/api/types';
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
export type ComparisonProps = {
firstImage: ImageDTO;
secondImage: ImageDTO;
containerDims: Dimensions;
};
export const fitDimsToContainer = (containerDims: Dimensions, imageDims: Dimensions): Dimensions => {
// Fall back to the image's dimensions if the container has no dimensions
if (containerDims.width === 0 || containerDims.height === 0) {
return { width: imageDims.width, height: imageDims.height };
}
// Fall back to the image's dimensions if the image fits within the container
if (imageDims.width <= containerDims.width && imageDims.height <= containerDims.height) {
return { width: imageDims.width, height: imageDims.height };
}
const targetAspectRatio = containerDims.width / containerDims.height;
const imageAspectRatio = imageDims.width / imageDims.height;
let width: number;
let height: number;
if (imageAspectRatio > targetAspectRatio) {
// Image is wider than container's aspect ratio
width = containerDims.width;
height = width / imageAspectRatio;
} else {
// Image is taller than container's aspect ratio
height = containerDims.height;
width = height * imageAspectRatio;
}
return { width, height };
};
/**
* Gets the dimensions of the second image in a comparison based on the comparison fit mode.
*/
export const getSecondImageDims = (
comparisonFit: ComparisonFit,
fittedDims: Dimensions,
firstImageDims: Dimensions,
secondImageDims: Dimensions
): Dimensions => {
const width =
comparisonFit === 'fill' ? fittedDims.width : (fittedDims.width * secondImageDims.width) / firstImageDims.width;
const height =
comparisonFit === 'fill' ? fittedDims.height : (fittedDims.height * secondImageDims.height) / firstImageDims.height;
return { width, height };
};
export const selectComparisonImages = createMemoizedSelector(selectGallerySlice, (gallerySlice) => {
const firstImage = gallerySlice.selection.slice(-1)[0] ?? null;
const secondImage = gallerySlice.imageToCompare;
return { firstImage, secondImage };
});

View File

@@ -1,31 +0,0 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { useCallback } from 'react';
export const useImageViewer = () => {
const dispatch = useAppDispatch();
const isComparing = useAppSelector((s) => s.gallery.imageToCompare !== null);
const isOpen = useAppSelector((s) => s.gallery.isImageViewerOpen);
const onClose = useCallback(() => {
if (isComparing && isOpen) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(isImageViewerOpenChanged(false));
}
}, [dispatch, isComparing, isOpen]);
const onOpen = useCallback(() => {
dispatch(isImageViewerOpenChanged(true));
}, [dispatch]);
const onToggle = useCallback(() => {
if (isComparing && isOpen) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(isImageViewerOpenChanged(!isOpen));
}
}, [dispatch, isComparing, isOpen]);
return { isOpen, onOpen, onClose, onToggle, isComparing };
};

View File

@@ -0,0 +1,22 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { useCallback } from 'react';
export const useImageViewer = () => {
const dispatch = useAppDispatch();
const isOpen = useAppSelector((s) => s.gallery.isImageViewerOpen);
const onClose = useCallback(() => {
dispatch(isImageViewerOpenChanged(false));
}, [dispatch]);
const onOpen = useCallback(() => {
dispatch(isImageViewerOpenChanged(true));
}, [dispatch]);
const onToggle = useCallback(() => {
dispatch(isImageViewerOpenChanged(!isOpen));
}, [dispatch, isOpen]);
return { isOpen, onOpen, onClose, onToggle };
};

View File

@@ -14,7 +14,7 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
const NextPrevImageButtons = () => {
const { t } = useTranslation();
const { prevImage, nextImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
const { handleLeftImage, handleRightImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
const {
areMoreImagesAvailable,
@@ -30,7 +30,7 @@ const NextPrevImageButtons = () => {
aria-label={t('accessibility.previousImage')}
icon={<PiCaretLeftBold size={64} />}
variant="unstyled"
onClick={prevImage}
onClick={handleLeftImage}
boxSize={16}
sx={nextPrevButtonStyles}
/>
@@ -42,7 +42,7 @@ const NextPrevImageButtons = () => {
aria-label={t('accessibility.nextImage')}
icon={<PiCaretRightBold size={64} />}
variant="unstyled"
onClick={nextImage}
onClick={handleRightImage}
boxSize={16}
sx={nextPrevButtonStyles}
/>

View File

@@ -27,16 +27,16 @@ export const useGalleryHotkeys = () => {
useGalleryNavigation();
useHotkeys(
['left', 'alt+left'],
(e) => {
canNavigateGallery && handleLeftImage(e.altKey);
'left',
() => {
canNavigateGallery && handleLeftImage();
},
[handleLeftImage, canNavigateGallery]
);
useHotkeys(
['right', 'alt+right'],
(e) => {
'right',
() => {
if (!canNavigateGallery) {
return;
}
@@ -45,29 +45,29 @@ export const useGalleryHotkeys = () => {
return;
}
if (!isOnLastImage) {
handleRightImage(e.altKey);
handleRightImage();
}
},
[isOnLastImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleRightImage, canNavigateGallery]
);
useHotkeys(
['up', 'alt+up'],
(e) => {
handleUpImage(e.altKey);
'up',
() => {
handleUpImage();
},
{ preventDefault: true },
[handleUpImage]
);
useHotkeys(
['down', 'alt+down'],
(e) => {
'down',
() => {
if (!areImagesBelowCurrent && areMoreImagesAvailable && !isFetching) {
handleLoadMoreImages();
return;
}
handleDownImage(e.altKey);
handleDownImage();
},
{ preventDefault: true },
[areImagesBelowCurrent, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleDownImage]

View File

@@ -1,11 +1,11 @@
import { useAltModifier } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
import { imageItemContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridItemContainer';
import { imageListContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridListContainer';
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { imageSelected, imageToCompareChanged } from 'features/gallery/store/gallerySlice';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { getIsVisible } from 'features/gallery/util/getIsVisible';
import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign';
import { clamp } from 'lodash-es';
@@ -106,12 +106,10 @@ const getImageFuncs = {
};
type UseGalleryNavigationReturn = {
handleLeftImage: (alt?: boolean) => void;
handleRightImage: (alt?: boolean) => void;
handleUpImage: (alt?: boolean) => void;
handleDownImage: (alt?: boolean) => void;
prevImage: () => void;
nextImage: () => void;
handleLeftImage: () => void;
handleRightImage: () => void;
handleUpImage: () => void;
handleDownImage: () => void;
isOnFirstImage: boolean;
isOnLastImage: boolean;
areImagesBelowCurrent: boolean;
@@ -125,15 +123,7 @@ type UseGalleryNavigationReturn = {
*/
export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
const dispatch = useAppDispatch();
const alt = useAltModifier();
const lastSelectedImage = useAppSelector((s) => {
const lastSelected = s.gallery.selection.slice(-1)[0] ?? null;
if (alt) {
return s.gallery.imageToCompare ?? lastSelected;
} else {
return lastSelected;
}
});
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
const {
queryResult: { data },
} = useGalleryImages();
@@ -146,7 +136,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
}, [lastSelectedImage, data]);
const handleNavigation = useCallback(
(direction: 'left' | 'right' | 'up' | 'down', alt?: boolean) => {
(direction: 'left' | 'right' | 'up' | 'down') => {
if (!data) {
return;
}
@@ -154,14 +144,10 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
if (!image || index === lastSelectedImageIndex) {
return;
}
if (alt) {
dispatch(imageToCompareChanged(image));
} else {
dispatch(imageSelected(image));
}
dispatch(imageSelected(image));
scrollToImage(image.image_name, index);
},
[data, lastSelectedImageIndex, dispatch]
[dispatch, lastSelectedImageIndex, data]
);
const isOnFirstImage = useMemo(() => lastSelectedImageIndex === 0, [lastSelectedImageIndex]);
@@ -176,41 +162,21 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
return lastSelectedImageIndex + imagesPerRow < loadedImagesCount;
}, [lastSelectedImageIndex, loadedImagesCount]);
const handleLeftImage = useCallback(
(alt?: boolean) => {
handleNavigation('left', alt);
},
[handleNavigation]
);
const handleLeftImage = useCallback(() => {
handleNavigation('left');
}, [handleNavigation]);
const handleRightImage = useCallback(
(alt?: boolean) => {
handleNavigation('right', alt);
},
[handleNavigation]
);
const handleRightImage = useCallback(() => {
handleNavigation('right');
}, [handleNavigation]);
const handleUpImage = useCallback(
(alt?: boolean) => {
handleNavigation('up', alt);
},
[handleNavigation]
);
const handleUpImage = useCallback(() => {
handleNavigation('up');
}, [handleNavigation]);
const handleDownImage = useCallback(
(alt?: boolean) => {
handleNavigation('down', alt);
},
[handleNavigation]
);
const nextImage = useCallback(() => {
handleRightImage();
}, [handleRightImage]);
const prevImage = useCallback(() => {
handleLeftImage();
}, [handleLeftImage]);
const handleDownImage = useCallback(() => {
handleNavigation('down');
}, [handleNavigation]);
return {
handleLeftImage,
@@ -220,7 +186,5 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
isOnFirstImage,
isOnLastImage,
areImagesBelowCurrent,
nextImage,
prevImage,
};
};

View File

@@ -36,7 +36,6 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
shiftKey: e.shiftKey,
ctrlKey: e.ctrlKey,
metaKey: e.metaKey,
altKey: e.altKey,
})
);
},

View File

@@ -6,7 +6,7 @@ import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import type { BoardId, ComparisonMode, GalleryState, GalleryView } from './types';
import type { BoardId, GalleryState, GalleryView } from './types';
import { IMAGE_LIMIT, INITIAL_IMAGE_LIMIT } from './types';
const initialGalleryState: GalleryState = {
@@ -22,9 +22,6 @@ const initialGalleryState: GalleryState = {
limit: INITIAL_IMAGE_LIMIT,
offset: 0,
isImageViewerOpen: true,
imageToCompare: null,
comparisonMode: 'slider',
comparisonFit: 'fill',
};
export const gallerySlice = createSlice({
@@ -37,28 +34,6 @@ export const gallerySlice = createSlice({
selectionChanged: (state, action: PayloadAction<ImageDTO[]>) => {
state.selection = uniqBy(action.payload, (i) => i.image_name);
},
imageToCompareChanged: (state, action: PayloadAction<ImageDTO | null>) => {
state.imageToCompare = action.payload;
if (action.payload) {
state.isImageViewerOpen = true;
}
},
comparisonModeChanged: (state, action: PayloadAction<ComparisonMode>) => {
state.comparisonMode = action.payload;
},
comparisonModeCycled: (state) => {
switch (state.comparisonMode) {
case 'slider':
state.comparisonMode = 'side-by-side';
break;
case 'side-by-side':
state.comparisonMode = 'hover';
break;
case 'hover':
state.comparisonMode = 'slider';
break;
}
},
shouldAutoSwitchChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAutoSwitch = action.payload;
},
@@ -104,16 +79,6 @@ export const gallerySlice = createSlice({
isImageViewerOpenChanged: (state, action: PayloadAction<boolean>) => {
state.isImageViewerOpen = action.payload;
},
comparedImagesSwapped: (state) => {
if (state.imageToCompare) {
const oldSelection = state.selection;
state.selection = [state.imageToCompare];
state.imageToCompare = oldSelection[0] ?? null;
}
},
comparisonFitChanged: (state, action: PayloadAction<'contain' | 'fill'>) => {
state.comparisonFit = action.payload;
},
},
extraReducers: (builder) => {
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
@@ -152,11 +117,6 @@ export const {
moreImagesLoaded,
alwaysShowImageSizeBadgeChanged,
isImageViewerOpenChanged,
imageToCompareChanged,
comparisonModeChanged,
comparedImagesSwapped,
comparisonFitChanged,
comparisonModeCycled,
} = gallerySlice.actions;
const isAnyBoardDeleted = isAnyOf(
@@ -178,13 +138,5 @@ export const galleryPersistConfig: PersistConfig<GalleryState> = {
name: gallerySlice.name,
initialState: initialGalleryState,
migrate: migrateGalleryState,
persistDenylist: [
'selection',
'selectedBoardId',
'galleryView',
'offset',
'limit',
'isImageViewerOpen',
'imageToCompare',
],
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'offset', 'limit', 'isImageViewerOpen'],
};

View File

@@ -7,8 +7,6 @@ export const IMAGE_LIMIT = 20;
export type GalleryView = 'images' | 'assets';
export type BoardId = 'none' | (string & Record<never, never>);
export type ComparisonMode = 'slider' | 'side-by-side' | 'hover';
export type ComparisonFit = 'contain' | 'fill';
export type GalleryState = {
selection: ImageDTO[];
@@ -22,8 +20,5 @@ export type GalleryState = {
offset: number;
limit: number;
alwaysShowImageSizeBadge: boolean;
imageToCompare: ImageDTO | null;
comparisonMode: ComparisonMode;
comparisonFit: ComparisonFit;
isImageViewerOpen: boolean;
};

View File

@@ -1,7 +1,4 @@
import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import { objectKeys } from 'common/util/objectKeys';
import { shouldConcatPromptsChanged } from 'features/controlLayers/store/controlLayersSlice';
import type { Layer } from 'features/controlLayers/store/types';
import type { LoRA } from 'features/lora/store/loraSlice';
import type {
@@ -19,7 +16,6 @@ import { validators } from 'features/metadata/util/validators';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { size } from 'lodash-es';
import { assert } from 'tsafe';
import { parsers } from './parsers';
@@ -380,25 +376,54 @@ export const handlers = {
}),
} as const;
type ParsedValue = Awaited<ReturnType<(typeof handlers)[keyof typeof handlers]['parse']>>;
type RecallResults = Partial<Record<keyof typeof handlers, ParsedValue>>;
export const parseAndRecallPrompts = async (metadata: unknown) => {
const keysToRecall: (keyof typeof handlers)[] = [
'positivePrompt',
'negativePrompt',
'sdxlPositiveStylePrompt',
'sdxlNegativeStylePrompt',
];
const recalled = await recallKeys(keysToRecall, metadata);
if (size(recalled) > 0) {
const results = await Promise.allSettled([
handlers.positivePrompt.parse(metadata).then((positivePrompt) => {
if (!handlers.positivePrompt.recall) {
return;
}
handlers.positivePrompt?.recall(positivePrompt);
}),
handlers.negativePrompt.parse(metadata).then((negativePrompt) => {
if (!handlers.negativePrompt.recall) {
return;
}
handlers.negativePrompt?.recall(negativePrompt);
}),
handlers.sdxlPositiveStylePrompt.parse(metadata).then((sdxlPositiveStylePrompt) => {
if (!handlers.sdxlPositiveStylePrompt.recall) {
return;
}
handlers.sdxlPositiveStylePrompt?.recall(sdxlPositiveStylePrompt);
}),
handlers.sdxlNegativeStylePrompt.parse(metadata).then((sdxlNegativeStylePrompt) => {
if (!handlers.sdxlNegativeStylePrompt.recall) {
return;
}
handlers.sdxlNegativeStylePrompt?.recall(sdxlNegativeStylePrompt);
}),
]);
if (results.some((result) => result.status === 'fulfilled')) {
parameterSetToast(t('metadata.allPrompts'));
}
};
export const parseAndRecallImageDimensions = async (metadata: unknown) => {
const recalled = recallKeys(['width', 'height'], metadata);
if (size(recalled) > 0) {
const results = await Promise.allSettled([
handlers.width.parse(metadata).then((width) => {
if (!handlers.width.recall) {
return;
}
handlers.width?.recall(width);
}),
handlers.height.parse(metadata).then((height) => {
if (!handlers.height.recall) {
return;
}
handlers.height?.recall(height);
}),
]);
if (results.some((result) => result.status === 'fulfilled')) {
parameterSetToast(t('metadata.imageDimensions'));
}
};
@@ -413,20 +438,28 @@ export const parseAndRecallAllMetadata = async (
toControlLayers: boolean,
skip: (keyof typeof handlers)[] = []
) => {
const skipKeys = deepClone(skip);
const skipKeys = skip ?? [];
if (toControlLayers) {
skipKeys.push(...TO_CONTROL_LAYERS_SKIP_KEYS);
} else {
skipKeys.push(...NOT_TO_CONTROL_LAYERS_SKIP_KEYS);
}
const results = await Promise.allSettled(
objectKeys(handlers)
.filter((key) => !skipKeys.includes(key))
.map((key) => {
const { parse, recall } = handlers[key];
return parse(metadata).then((value) => {
if (!recall) {
return;
}
/* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
recall(value);
});
})
);
// We may need to take some further action depending on what was recalled. For example, we need to disable SDXL prompt
// concat if the negative or positive style prompt was set. Because the recalling is all async, we need to collect all
// results
const keysToRecall = objectKeys(handlers).filter((key) => !skipKeys.includes(key));
const recalled = await recallKeys(keysToRecall, metadata);
if (size(recalled) > 0) {
if (results.some((result) => result.status === 'fulfilled')) {
toast({
id: 'PARAMETER_SET',
title: t('toast.parametersSet'),
@@ -440,43 +473,3 @@ export const parseAndRecallAllMetadata = async (
});
}
};
/**
* Recalls a set of keys from metadata.
* Includes special handling for some metadata where recalling may have side effects. For example, recalling a "style"
* prompt that is different from the "positive" or "negative" prompt should disable prompt concatenation.
* @param keysToRecall An array of keys to recall.
* @param metadata The metadata to recall from
* @returns A promise that resolves to an object containing the recalled values.
*/
const recallKeys = async (keysToRecall: (keyof typeof handlers)[], metadata: unknown): Promise<RecallResults> => {
const { dispatch } = getStore();
const recalled: RecallResults = {};
for (const key of keysToRecall) {
const { parse, recall } = handlers[key];
if (!recall) {
continue;
}
try {
const value = await parse(metadata);
/* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
await recall(value);
recalled[key] = value;
} catch {
// no-op
}
}
if (
(recalled['sdxlPositiveStylePrompt'] && recalled['sdxlPositiveStylePrompt'] !== recalled['positivePrompt']) ||
(recalled['sdxlNegativeStylePrompt'] && recalled['sdxlNegativeStylePrompt'] !== recalled['negativePrompt'])
) {
// If we set the negative style prompt or positive style prompt, we should disable prompt concat
dispatch(shouldConcatPromptsChanged(false));
} else {
// Otherwise, we should enable prompt concat
dispatch(shouldConcatPromptsChanged(true));
}
return recalled;
};

View File

@@ -1,7 +1,6 @@
import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
import type { ModelIdentifier } from 'features/nodes/types/v2/common';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
@@ -108,30 +107,19 @@ export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
/**
* Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers.
* @param modelIdentifier The model identifier. This can be a MM1 or MM2 identifier. In every case, we attempt to fetch
* the model config from the server to ensure that the model identifier is valid and represents an installed model.
* @param modelIdentifier The model identifier. The MM2 format `{key: string}` simply extracts the key. The MM1 format
* `{model_name: string, base_model: BaseModelType}` must do a network request to fetch the key.
* @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers.
* @param message An optional custom message to include in the error if the model identifier is invalid.
* @returns A promise that resolves to the model key.
* @throws {InvalidModelConfigError} If the model identifier is invalid.
*/
export const getModelKey = async (
modelIdentifier: unknown | ModelIdentifierField | ModelIdentifier,
type: ModelType,
message?: string
): Promise<string> => {
export const getModelKey = async (modelIdentifier: unknown, type: ModelType, message?: string): Promise<string> => {
if (isModelIdentifier(modelIdentifier)) {
try {
// Check if the model exists by key
return (await fetchModelConfig(modelIdentifier.key)).key;
} catch {
// If not, fetch the model key by name and base model
return (await fetchModelConfigByAttrs(modelIdentifier.name, modelIdentifier.base, type)).key;
}
} else if (isModelIdentifierV2(modelIdentifier)) {
// Try by old-format model identifier
return modelIdentifier.key;
}
if (isModelIdentifierV2(modelIdentifier)) {
return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key;
}
// Nope, couldn't find it
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
};

View File

@@ -19,7 +19,7 @@ import {
redo,
undo,
} from 'features/nodes/store/nodesSlice';
import { $flow, $needsFit } from 'features/nodes/store/reactFlowInstance';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import type { CSSProperties, MouseEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
@@ -68,7 +68,6 @@ export const Flow = memo(() => {
const nodes = useAppSelector((s) => s.nodes.present.nodes);
const edges = useAppSelector((s) => s.nodes.present.edges);
const viewport = useStore($viewport);
const needsFit = useStore($needsFit);
const mayUndo = useAppSelector((s) => s.nodes.past.length > 0);
const mayRedo = useAppSelector((s) => s.nodes.future.length > 0);
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
@@ -93,16 +92,8 @@ export const Flow = memo(() => {
const onNodesChange: OnNodesChange = useCallback(
(nodeChanges) => {
dispatch(nodesChanged(nodeChanges));
const flow = $flow.get();
if (!flow) {
return;
}
if (needsFit) {
$needsFit.set(false);
flow.fitView();
}
},
[dispatch, needsFit]
[dispatch]
);
const onEdgesChange: OnEdgesChange = useCallback(

View File

@@ -15,20 +15,27 @@ const ViewportControls = () => {
const { t } = useTranslation();
const { zoomIn, zoomOut, fitView } = useReactFlow();
const dispatch = useAppDispatch();
// const shouldShowFieldTypeLegend = useAppSelector(
// (s) => s.nodes.present.shouldShowFieldTypeLegend
// );
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
const handleClickedZoomIn = useCallback(() => {
zoomIn({ duration: 300 });
zoomIn();
}, [zoomIn]);
const handleClickedZoomOut = useCallback(() => {
zoomOut({ duration: 300 });
zoomOut();
}, [zoomOut]);
const handleClickedFitView = useCallback(() => {
fitView({ duration: 300 });
fitView();
}, [fitView]);
// const handleClickedToggleFieldTypeLegend = useCallback(() => {
// dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
// }, [shouldShowFieldTypeLegend, dispatch]);
const handleClickedToggleMiniMapPanel = useCallback(() => {
dispatch(shouldShowMinimapPanelChanged(!shouldShowMinimapPanel));
}, [shouldShowMinimapPanel, dispatch]);

View File

@@ -1,7 +1,9 @@
import 'reactflow/dist/style.css';
import { Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
import QueueControls from 'features/queue/components/QueueControls';
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
import { usePanelStorage } from 'features/ui/hooks/usePanelStorage';
@@ -19,8 +21,14 @@ import { WorkflowName } from './WorkflowName';
const panelGroupStyles: CSSProperties = { height: '100%', width: '100%' };
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
return {
mode: workflow.mode,
};
});
const NodeEditorPanelGroup = () => {
const mode = useAppSelector((s) => s.workflow.mode);
const { mode } = useAppSelector(selector);
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
const panelStorage = usePanelStorage();

View File

@@ -1,12 +1,20 @@
import { Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
import { ModeToggle } from './ModeToggle';
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
return {
mode: workflow.mode,
};
});
export const WorkflowMenu = () => {
const mode = useAppSelector((s) => s.workflow.mode);
const { mode } = useAppSelector(selector);
return (
<Flex gap="2" alignItems="center">

View File

@@ -11,7 +11,8 @@ import { selectLastSelectedNode } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyInvocationOutput, ImageOutput } from 'services/api/types';
import type { ImageOutput } from 'services/api/types';
import type { AnyResult } from 'services/events/types';
import ImageOutputPreview from './outputs/ImageOutputPreview';
@@ -65,4 +66,4 @@ const InspectorOutputsTab = () => {
export default memo(InspectorOutputsTab);
const getKey = (result: AnyInvocationOutput, i: number) => `${result.type}-${i}`;
const getKey = (result: AnyResult, i: number) => `${result.type}-${i}`;

View File

@@ -2,4 +2,3 @@ import { atom } from 'nanostores';
import type { ReactFlowInstance } from 'reactflow';
export const $flow = atom<ReactFlowInstance | null>(null);
export const $needsFit = atom<boolean>(true);

View File

@@ -144,4 +144,5 @@ const zImageOutput = z.object({
type: z.literal('image_output'),
});
export type ImageOutput = z.infer<typeof zImageOutput>;
export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success;
// #endregion

View File

@@ -1,7 +1,8 @@
import type { NodesState } from 'features/nodes/store/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { omit, reduce } from 'lodash-es';
import type { AnyInvocation, Graph } from 'services/api/types';
import type { Graph } from 'services/api/types';
import type { AnyInvocation } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid';
/**

View File

@@ -1,7 +1,6 @@
import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ImageComparisonDroppable } from 'features/gallery/components/ImageViewer/ImageComparisonDroppable';
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
import { ImageViewerWorkflows } from 'features/gallery/components/ImageViewer/ImageViewerWorkflows';
import NodeEditor from 'features/nodes/components/NodeEditor';
import { memo } from 'react';
import { ReactFlowProvider } from 'reactflow';
@@ -11,8 +10,7 @@ const NodesTab = () => {
if (mode === 'view') {
return (
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
<ImageViewer />
<ImageComparisonDroppable />
<ImageViewerWorkflows />
</Box>
);
}

View File

@@ -1,17 +1,13 @@
import { Box } from '@invoke-ai/ui-library';
import { ControlLayersEditor } from 'features/controlLayers/components/ControlLayersEditor';
import { ImageComparisonDroppable } from 'features/gallery/components/ImageViewer/ImageComparisonDroppable';
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { memo } from 'react';
const TextToImageTab = () => {
const imageViewer = useImageViewer();
return (
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
<ControlLayersEditor />
{imageViewer.isOpen && <ImageViewer />}
<ImageComparisonDroppable />
<ImageViewer />
</Box>
);
};

View File

@@ -41,7 +41,7 @@ const UnifiedCanvasTab = () => {
>
<IAICanvasToolbar />
<IAICanvas />
{isValidDrop(droppableData, active?.data.current) && (
{isValidDrop(droppableData, active) && (
<IAIDropOverlay isOver={isOver} label={t('toast.setCanvasInitialImage')} />
)}
</Flex>

File diff suppressed because one or more lines are too long

View File

@@ -122,6 +122,7 @@ export type ModelInstallStatus = S['InstallStatus'];
// Graphs
export type Graph = S['Graph'];
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
export type GraphExecutionState = S['GraphExecutionState'];
export type Batch = S['Batch'];
export type SessionQueueItemDTO = S['SessionQueueItemDTO'];
export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy'];
@@ -131,14 +132,14 @@ export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO'];
type KeysOfUnion<T> = T extends T ? keyof T : never;
export type AnyInvocation = Exclude<
NonNullable<S['Graph']['nodes']>[string],
Graph['nodes'][string],
S['CoreMetadataInvocation'] | S['MetadataInvocation'] | S['MetadataItemInvocation'] | S['MergeMetadataInvocation']
>;
export type AnyInvocationIncMetadata = NonNullable<S['Graph']['nodes']>[string];
export type AnyInvocationIncMetadata = S['Graph']['nodes'][string];
export type InvocationType = AnyInvocation['type'];
type InvocationOutputMap = S['InvocationOutputMap'];
export type AnyInvocationOutput = InvocationOutputMap[InvocationType];
type AnyInvocationOutput = InvocationOutputMap[InvocationType];
export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>;
// export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T];

View File

@@ -1,12 +1,21 @@
import type { S } from 'services/api/types';
import type { Graph, GraphExecutionState, S } from 'services/api/types';
export type AnyInvocation = NonNullable<NonNullable<Graph['nodes']>[string]>;
export type AnyResult = NonNullable<GraphExecutionState['results'][string]>;
export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
export type InvocationStartedEvent = S['InvocationStartedEvent'];
export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent'];
export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
export type InvocationErrorEvent = S['InvocationErrorEvent'];
export type InvocationStartedEvent = Omit<S['InvocationStartedEvent'], 'invocation'> & { invocation: AnyInvocation };
export type InvocationDenoiseProgressEvent = Omit<S['InvocationDenoiseProgressEvent'], 'invocation'> & {
invocation: AnyInvocation;
};
export type InvocationCompleteEvent = Omit<S['InvocationCompleteEvent'], 'result' | 'invocation'> & {
result: AnyResult;
invocation: AnyInvocation;
};
export type InvocationErrorEvent = Omit<S['InvocationErrorEvent'], 'invocation'> & { invocation: AnyInvocation };
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];

View File

@@ -1 +1 @@
__version__ = "4.2.4"
__version__ = "4.2.3"

View File

@@ -55,10 +55,10 @@ dependencies = [
# Core application dependencies, pinned for reproducible builds.
"fastapi-events==0.11.0",
"fastapi==0.111.0",
"fastapi==0.110.0",
"huggingface-hub==0.23.1",
"pydantic-settings==2.2.1",
"pydantic==2.7.2",
"pydantic==2.6.3",
"python-socketio==5.11.1",
"uvicorn[standard]==0.28.0",

View File

@@ -7,10 +7,9 @@ def main():
# Change working directory to the repo root
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from invokeai.app.api_app import app
from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.app.api_app import custom_openapi
schema = get_openapi_func(app)()
schema = custom_openapi()
json.dump(schema, sys.stdout, indent=2)

View File

@@ -1,6 +1,5 @@
import pytest
from pydantic import TypeAdapter
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@@ -714,4 +713,4 @@ def test_iterate_accepts_collection():
def test_graph_can_generate_schema():
# Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
models_json_schema([(Graph, "serialization")])
_ = Graph.model_json_schema()