Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-29 21:27:59 -05:00)

# Compare commits

64 commits, comparing `feat/metad...` against `psychedeli...`
Commits (SHA1):

- `75624e9158`
- `a2613948d8`
- `f8392b2f78`
- `358116bc22`
- `1e3590111d`
- `063b800280`
- `3935bf92c8`
- `066e09b517`
- `869b4a8d49`
- `13919ff300`
- `634e5652ef`
- `9bdc718df5`
- `73ca8ccdb3`
- `f37ffda966`
- `5a9777d443`
- `8072c05ee0`
- `75ff4f4ca3`
- `30df123221`
- `06193ddbe8`
- `ce5122f87c`
- `43ebd68313`
- `ec19fcafb1`
- `6fcc7d4c4b`
- `912087e4dc`
- `593fb95213`
- `6d821b32d3`
- `297f96c16b`
- `0e53b27655`
- `35ae9f6e71`
- `a1d9e6b871`
- `f05379f965`
- `e34e6d6e80`
- `86cb53342a`
- `e3de996525`
- `25a71a1791`
- `d16583ad1c`
- `46db1dd18f`
- `4c9344b0ee`
- `cba31efd78`
- `4d01b5c0f2`
- `e02af8f518`
- `c485cf568b`
- `51451cbf21`
- `0363a06963`
- `cc280cbef1`
- `7544eadd48`
- `7d683b4db6`
- `60b3c6a201`
- `88c8cb61f0`
- `43fbac26df`
- `627444e17c`
- `5601858f4f`
- `b5e1ba34b3`
- `58aa159a50`
- `d8f7c19030`
- `24132a7950`
- `45d172d5a8`
- `3cb6d333f6`
- `4570702dd0`
- `1d107f30e5`
- `79084e9e20`
- `fc9b4539a3`
- `09ef57718e`
- `cab8239ba8`
````diff
@@ -296,8 +296,18 @@ code for InvokeAI. For this to work, you will need to install the
 on your system, please see the [Git Installation
 Guide](https://github.com/git-guides/install-git)
 
+You will also need to install the [frontend development toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/docs/contributing/contribution_guides/contributingToFrontend.md).
+
+If you have a "normal" installation, you should create a totally separate virtual environment for the git-based installation, else the two may interfere.
+
+> **Why do I need the frontend toolchain**?
+>
+> The InvokeAI project uses trunk-based development. That means our `main` branch is the development branch, and releases are tags on that branch. Because development is very active, we don't keep an updated build of the UI in `main` - we only build it for production releases.
+>
+> That means that between releases, to have a functioning application when running directly from the repo, you will need to run the UI in dev mode or build it regularly (any time the UI code changes).
+
+1. Create a fork of the InvokeAI repository through the GitHub UI or [this link](https://github.com/invoke-ai/InvokeAI/fork)
-1. From the command line, run this command:
+2. From the command line, run this command:
 
    ```bash
    git clone https://github.com/<your_github_username>/InvokeAI.git
    ```
````
```diff
@@ -305,10 +315,10 @@ Guide](https://github.com/git-guides/install-git)
 
    This will create a directory named `InvokeAI` and populate it with the
    full source code from your fork of the InvokeAI repository.
 
-2. Activate the InvokeAI virtual environment as per step (4) of the manual
+3. Activate the InvokeAI virtual environment as per step (4) of the manual
    installation protocol (important!)
 
-3. Enter the InvokeAI repository directory and run one of these
+4. Enter the InvokeAI repository directory and run one of these
    commands, based on your GPU:
 
    === "CUDA (NVidia)"
@@ -334,11 +344,15 @@ installation protocol (important!)
    Be sure to pass `-e` (for an editable install) and don't forget the
    dot ("."). It is part of the command.
 
-   You can now run `invokeai` and its related commands. The code will be
+5. Install the [frontend toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/docs/contributing/contribution_guides/contributingToFrontend.md) and do a production build of the UI as described.
+
+6. You can now run `invokeai` and its related commands. The code will be
    read from the repository, so that you can edit the .py source files
    and watch the code's behavior change.
 
-4. If you wish to contribute to the InvokeAI project, you are
+   When you pull in new changes to the repo, be sure to re-build the UI.
+
+7. If you wish to contribute to the InvokeAI project, you are
    encouraged to establish a GitHub account and "fork"
    https://github.com/invoke-ai/InvokeAI into your own copy of the
    repository. You can then use GitHub functions to create and submit
```
```diff
@@ -121,18 +121,6 @@ To be imported, an .obj must use triangulated meshes, so make sure to enable that
 **Example Usage:**
 
 --------------------------------
-### Enhance Image (simple adjustments)
-
-**Description:** Boost or reduce color saturation, contrast, brightness, sharpness, or invert colors of any image at any stage with this simple wrapper for pillow [PIL]'s ImageEnhance module.
-
-Color inversion is toggled with a simple switch, while each of the four enhancer modes are activated by entering a value other than 1 in each corresponding input field. Values less than 1 will reduce the corresponding property, while values greater than 1 will enhance it.
-
-**Node Link:** https://github.com/dwringer/image-enhance-node
-
-**Example Usage:**
-
---------------------------------
 ### Generative Grammar-Based Prompt Nodes
@@ -153,16 +141,26 @@ This includes 3 Nodes:
 
 **Description:** This is a pack of nodes for composing masks and images, including a simple text mask creator and both image and latent offset nodes. The offsets wrap around, so these can be used in conjunction with the Seamless node to progressively generate centered on different parts of the seamless tiling.
 
-This includes 4 Nodes:
-- *Text Mask (simple 2D)* - create and position a white on black (or black on white) line of text using any font locally available to Invoke.
+This includes 14 Nodes:
+- *Adjust Image Hue Plus* - Rotate the hue of an image in one of several different color spaces.
+- *Blend Latents/Noise (Masked)* - Use a mask to blend part of one latents tensor [including Noise outputs] into another. Can be used to "renoise" sections during a multi-stage [masked] denoising process.
+- *Enhance Image* - Boost or reduce color saturation, contrast, brightness, sharpness, or invert colors of any image at any stage with this simple wrapper for pillow [PIL]'s ImageEnhance module.
+- *Equivalent Achromatic Lightness* - Calculates image lightness accounting for Helmholtz-Kohlrausch effect based on a method described by High, Green, and Nussbaum (2023).
+- *Text to Mask (Clipseg)* - Input a prompt and an image to generate a mask representing areas of the image matched by the prompt.
+- *Text to Mask Advanced (Clipseg)* - Output up to four prompt masks combined with logical "and", logical "or", or as separate channels of an RGBA image.
+- *Image Layer Blend* - Perform a layered blend of two images using alpha compositing. Opacity of top layer is selectable, with optional mask and several different blend modes/color spaces.
+- *Image Compositor* - Take a subject from an image with a flat backdrop and layer it on another image using a chroma key or flood select background removal.
+- *Image Dilate or Erode* - Dilate or expand a mask (or any image!). This is equivalent to an expand/contract operation.
+- *Image Value Thresholds* - Clip an image to pure black/white beyond specified thresholds.
 - *Offset Latents* - Offset a latents tensor in the vertical and/or horizontal dimensions, wrapping it around.
 - *Offset Image* - Offset an image in the vertical and/or horizontal dimensions, wrapping it around.
+- *Shadows/Highlights/Midtones* - Extract three masks (with adjustable hard or soft thresholds) representing shadows, midtones, and highlights regions of an image.
+- *Text Mask (simple 2D)* - create and position a white on black (or black on white) line of text using any font locally available to Invoke.
 
 **Node Link:** https://github.com/dwringer/composition-nodes
 
-**Example Usage:**
+**Nodes and Output Examples:**
 
 --------------------------------
 ### Size Stepper Nodes
```
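The Enhance Image node folded into the composition-nodes pack above is described as a thin wrapper around Pillow's `ImageEnhance` module. A minimal sketch of that behavior, assuming only Pillow is installed; the factor semantics (1.0 leaves a property unchanged, below 1 reduces it, above 1 boosts it) follow the node description, and the function name is illustrative:

```python
from PIL import Image, ImageEnhance, ImageOps

def enhance_image(
    image: Image.Image,
    saturation: float = 1.0,
    contrast: float = 1.0,
    brightness: float = 1.0,
    sharpness: float = 1.0,
    invert: bool = False,
) -> Image.Image:
    """Apply Pillow's four ImageEnhance modes; a factor of 1.0 is a no-op."""
    for enhancer_cls, factor in (
        (ImageEnhance.Color, saturation),     # "color saturation" in the node
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Sharpness, sharpness),
    ):
        if factor != 1.0:
            image = enhancer_cls(image).enhance(factor)
    if invert:
        # Color inversion is a simple on/off switch in the node.
        image = ImageOps.invert(image.convert("RGB"))
    return image
```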
```diff
@@ -49,7 +49,7 @@ def check_internet() -> bool:
         return False
 
 
-logger = InvokeAILogger.getLogger()
+logger = InvokeAILogger.get_logger()
 
 
 class ApiDependencies:
```
```diff
@@ -45,17 +45,13 @@ async def upload_image(
     if not file.content_type.startswith("image"):
         raise HTTPException(status_code=415, detail="Not an image")
 
-    metadata: Optional[str] = None
-    workflow: Optional[str] = None
-
     contents = await file.read()
 
     try:
         pil_image = Image.open(io.BytesIO(contents))
         if crop_visible:
             bbox = pil_image.getbbox()
             pil_image = pil_image.crop(bbox)
-        metadata = pil_image.info.get("invokeai_metadata", None)
-        workflow = pil_image.info.get("invokeai_workflow", None)
     except Exception:
         # Error opening the image
         raise HTTPException(status_code=415, detail="Failed to read image")
@@ -67,8 +63,6 @@ async def upload_image(
         image_category=image_category,
         session_id=session_id,
         board_id=board_id,
-        metadata=metadata,
-        workflow=workflow,
         is_intermediate=is_intermediate,
     )
```
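The removed side of the upload handler pulls `invokeai_metadata` and `invokeai_workflow` straight out of the PNG text chunks that Pillow exposes via `pil_image.info`. A self-contained sketch of that round trip, assuming Pillow and PNG output; the key names come from the hunk above:

```python
import io
import json

from PIL import Image
from PIL.PngImagePlugin import PngInfo

# Write: embed metadata and workflow as PNG tEXt chunks.
image = Image.new("RGB", (64, 64))
info = PngInfo()
info.add_text("invokeai_metadata", json.dumps({"seed": 123}))
info.add_text("invokeai_workflow", json.dumps({"nodes": []}))
buf = io.BytesIO()
image.save(buf, format="PNG", pnginfo=info)

# Read: Pillow surfaces the chunks in .info, as the upload route did.
reread = Image.open(io.BytesIO(buf.getvalue()))
metadata = reread.info.get("invokeai_metadata", None)
workflow = reread.info.get("invokeai_workflow", None)
print(metadata, workflow)
```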
```diff
@@ -146,7 +146,8 @@ async def update_model(
 async def import_model(
     location: str = Body(description="A model path, repo_id or URL to import"),
     prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
-        description="Prediction type for SDv2 checkpoint files", default="v_prediction"
+        description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints",
+        default=None,
     ),
 ) -> ImportModelResponse:
     """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
```
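This hunk stops forcing `v_prediction` for imported checkpoints and defaults to `None`, so the probe described in the docstring can decide unless the caller overrides it. A hedged sketch of the resolution logic this implies; the names here are illustrative, not the actual prober API:

```python
from typing import Literal, Optional

PredictionType = Literal["v_prediction", "epsilon", "sample"]

def resolve_prediction_type(
    user_choice: Optional[PredictionType],
    probed_default: PredictionType,
) -> PredictionType:
    # With default=None the API no longer assumes every SDv2 checkpoint
    # is v_prediction; an explicit user choice still wins, otherwise
    # trust whatever model probing determined.
    return user_choice if user_choice is not None else probed_default

assert resolve_prediction_type(None, "epsilon") == "epsilon"
assert resolve_prediction_type("v_prediction", "epsilon") == "v_prediction"
```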
```diff
@@ -8,7 +8,6 @@ app_config.parse_args()
 if True:  # hack to make flake8 happy with imports coming after setting up the config
     import asyncio
-    import logging
     import mimetypes
     import socket
     from inspect import signature
@@ -41,7 +40,9 @@ if True:  # hack to make flake8 happy with imports coming after setting up the config
         import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)
 
 
-logger = InvokeAILogger.getLogger(config=app_config)
+app_config = InvokeAIAppConfig.get_config()
+app_config.parse_args()
+logger = InvokeAILogger.get_logger(config=app_config)
 
 # fix for windows mimetypes registry entries being borked
 # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
@@ -223,7 +224,7 @@ def invoke_api():
                 exc_info=e,
             )
         else:
-            jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
+            jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
 
     port = find_port(app_config.port)
     if port != app_config.port:
@@ -242,7 +243,7 @@ def invoke_api():
 
     # replace uvicorn's loggers with InvokeAI's for consistent appearance
     for logname in ["uvicorn.access", "uvicorn"]:
-        log = logging.getLogger(logname)
+        log = InvokeAILogger.get_logger(logname)
         log.handlers.clear()
        for ch in logger.handlers:
             log.addHandler(ch)
```
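The last hunk swaps `logging.getLogger(logname)` for `InvokeAILogger.get_logger(logname)` but keeps the same rewiring pattern: clear the uvicorn loggers' own handlers and attach the app's handlers so all output shares one format. A standalone sketch of that pattern using only the stdlib; `app_logger` stands in for the InvokeAI logger:

```python
import logging

# The application's configured logger (stands in for InvokeAILogger's).
app_logger = logging.getLogger("app")
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("[%(name)s] %(levelname)s: %(message)s"))
app_logger.addHandler(handler)
app_logger.setLevel(logging.INFO)

# Replace uvicorn's handlers with the app's for a consistent appearance.
for logname in ["uvicorn.access", "uvicorn"]:
    log = logging.getLogger(logname)
    log.handlers.clear()
    for ch in app_logger.handlers:
        log.addHandler(ch)

logging.getLogger("uvicorn").warning("now formatted like the app logger")
```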
```diff
@@ -7,8 +7,6 @@ from .services.config import InvokeAIAppConfig
 # parse_args() must be called before any other imports. if it is not called first, consumers of the config
 # which are imported/used before parse_args() is called will get the default config values instead of the
 # values from the command line or config file.
-config = InvokeAIAppConfig.get_config()
-config.parse_args()
 
 if True:  # hack to make flake8 happy with imports coming after setting up the config
     import argparse
@@ -61,8 +59,9 @@ if True:  # hack to make flake8 happy with imports coming after setting up the config
     if torch.backends.mps.is_available():
         import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)
 
 
-logger = InvokeAILogger().getLogger(config=config)
+config = InvokeAIAppConfig.get_config()
+config.parse_args()
+logger = InvokeAILogger().get_logger(config=config)
 
 
 class CliCommand(BaseModel):
```
```diff
@@ -71,12 +71,7 @@ class FieldDescriptions:
     denoised_latents = "Denoised latents tensor"
     latents = "Latents tensor"
     strength = "Strength of denoising (proportional to steps)"
-    metadata = "Optional metadata to be saved with the image"
-    metadata_dict_collection = "Collection of MetadataDicts"
-    metadata_item_polymorphic = "A single metadata item or collection of metadata items"
-    metadata_item_label = "Label for this metadata item"
-    metadata_item_value = "The value for this metadata item (may be any type)"
-    workflow = "Optional workflow to be saved with the image"
+    core_metadata = "Optional core metadata to be written to image"
     interp_mode = "Interpolation mode"
     torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
     fp32 = "Whether or not to use full float32 precision"
@@ -180,12 +175,8 @@ class UIType(str, Enum):
     Scheduler = "Scheduler"
     WorkflowField = "WorkflowField"
     IsIntermediate = "IsIntermediate"
+    MetadataField = "MetadataField"
     BoardField = "BoardField"
-    Any = "Any"
-    MetadataItem = "MetadataItem"
-    MetadataItemCollection = "MetadataItemCollection"
-    MetadataItemPolymorphic = "MetadataItemPolymorphic"
-    MetadataDict = "MetadataDict"
     # endregion
@@ -631,8 +622,23 @@ class BaseInvocation(ABC, BaseModel):
     is_intermediate: bool = InputField(
         default=False, description="Whether or not this is an intermediate invocation.", ui_type=UIType.IsIntermediate
     )
+    workflow: Optional[str] = InputField(
+        default=None,
+        description="The workflow to save with the image",
+        ui_type=UIType.WorkflowField,
+    )
     use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
 
+    @validator("workflow", pre=True)
+    def validate_workflow_is_json(cls, v):
+        if v is None:
+            return None
+        try:
+            json.loads(v)
+        except json.decoder.JSONDecodeError:
+            raise ValueError("Workflow must be valid JSON")
+        return v
+
     UIConfig: ClassVar[Type[UIConfigBase]]
@@ -737,19 +743,3 @@ def invocation_output(
         return cls
 
     return wrapper
-
-
-class WithWorkflow(BaseModel):
-    workflow: Optional[str] = InputField(
-        default=None, description=FieldDescriptions.workflow, ui_type=UIType.WorkflowField
-    )
-
-    @validator("workflow", pre=True)
-    def validate_workflow_is_json(cls, v):
-        if v is None:
-            return None
-        try:
-            json.loads(v)
-        except json.decoder.JSONDecodeError:
-            raise ValueError("Workflow must be valid JSON")
-        return v
```
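The `validate_workflow_is_json` validator moving between `WithWorkflow` and `BaseInvocation` above runs with `pre=True`, so the raw string is checked before any other parsing and invalid JSON is rejected at model construction time. A standalone sketch of the same pattern with pydantic v1, which is what the `validator` import in these files implies; the `Node` model is illustrative:

```python
import json
from typing import Optional

from pydantic import BaseModel, validator

class Node(BaseModel):
    workflow: Optional[str] = None

    @validator("workflow", pre=True)
    def validate_workflow_is_json(cls, v):
        if v is None:
            return None
        try:
            json.loads(v)
        except json.decoder.JSONDecodeError:
            raise ValueError("Workflow must be valid JSON")
        return v

Node(workflow='{"nodes": []}')  # accepted
try:
    Node(workflow="not json")   # rejected with a ValidationError
except Exception as exc:
    print(exc)
```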
```diff
@@ -25,7 +25,6 @@ from controlnet_aux import (
 from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
 from pydantic import BaseModel, Field, validator
-from invokeai.app.invocations.metadata import WithMetadata
 
 from invokeai.app.invocations.primitives import ImageField, ImageOutput
@@ -39,7 +38,6 @@ from .baseinvocation import (
     InputField,
     InvocationContext,
     OutputField,
-    WithWorkflow,
     invocation,
     invocation_output,
 )
@@ -129,7 +127,7 @@ class ControlNetInvocation(BaseInvocation):
 @invocation(
     "image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
 )
-class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+class ImageProcessorInvocation(BaseInvocation):
     """Base class for invocations that preprocess images for ControlNet"""
 
     image: ImageField = InputField(description="The image to process")
@@ -152,7 +150,6 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             session_id=context.graph_execution_state_id,
             node_id=self.id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
```
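The recurring edit in the hunks that follow is dropping `WithMetadata` / `WithWorkflow` from class bases. Those mixins work because pydantic collects fields from every base class, so `class Foo(BaseInvocation, WithMetadata, WithWorkflow)` gains `metadata` and `workflow` fields without declaring them. A minimal sketch of the mechanism, independent of InvokeAI's actual classes (the models below are stand-ins):

```python
from typing import Optional

from pydantic import BaseModel

class WithMetadata(BaseModel):
    metadata: Optional[dict] = None

class WithWorkflow(BaseModel):
    workflow: Optional[str] = None

class BaseInvocation(BaseModel):
    id: str = "node-1"

# Multiple inheritance merges the fields of every base model.
class MyInvocation(BaseInvocation, WithMetadata, WithWorkflow):
    image_name: str

node = MyInvocation(image_name="foo.png", workflow='{"nodes": []}')
print(sorted(node.__fields__))  # ['id', 'image_name', 'metadata', 'workflow']
```

Removing the mixins therefore forces each node class to declare (or stop accepting) these fields explicitly, which is exactly what the per-class hunks below do.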
```diff
@@ -7,21 +7,13 @@ import cv2
 import numpy
 from PIL import Image, ImageChops, ImageFilter, ImageOps
 
-from invokeai.app.invocations.metadata import WithMetadata
+from invokeai.app.invocations.metadata import CoreMetadata
 from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
 from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
 from invokeai.backend.image_util.safety_checker import SafetyChecker
 
 from ..models.image import ImageCategory, ResourceOrigin
-from .baseinvocation import (
-    BaseInvocation,
-    FieldDescriptions,
-    Input,
-    InputField,
-    InvocationContext,
-    WithWorkflow,
-    invocation,
-)
+from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, invocation
 
 
 @invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
@@ -45,7 +37,7 @@ class ShowImageInvocation(BaseInvocation):
 
 
 @invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
-class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+class BlankImageInvocation(BaseInvocation):
     """Creates a blank image and forwards it to the pipeline"""
 
     width: int = InputField(default=512, description="The width of the image")
@@ -63,7 +55,6 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -75,7 +66,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
 
 
 @invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
-class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImageCropInvocation(BaseInvocation):
     """Crops an image to a specified box. The box can be outside of the image."""
 
     image: ImageField = InputField(description="The image to crop")
@@ -97,7 +88,6 @@ class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -109,7 +99,7 @@ class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
 
 @invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.1")
-class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImagePasteInvocation(BaseInvocation):
     """Pastes an image into another image."""
 
     base_image: ImageField = InputField(description="The base image")
@@ -151,7 +141,6 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -163,7 +152,7 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
 
 @invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
-class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class MaskFromAlphaInvocation(BaseInvocation):
     """Extracts the alpha channel of an image as a mask."""
 
     image: ImageField = InputField(description="The image to create the mask from")
@@ -183,7 +172,6 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -195,7 +183,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
 
 @invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
-class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImageMultiplyInvocation(BaseInvocation):
     """Multiplies two images together using `PIL.ImageChops.multiply()`."""
 
     image1: ImageField = InputField(description="The first image to multiply")
@@ -214,7 +202,6 @@ class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -229,7 +216,7 @@ IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
 
 
 @invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
-class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImageChannelInvocation(BaseInvocation):
     """Gets a channel from an image."""
 
     image: ImageField = InputField(description="The image to get the channel from")
@@ -247,7 +234,6 @@ class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -262,7 +248,7 @@ IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
 
 
 @invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
-class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImageConvertInvocation(BaseInvocation):
     """Converts an image to a different mode."""
 
     image: ImageField = InputField(description="The image to convert")
@@ -280,7 +266,6 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -292,7 +277,7 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
 
 @invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
-class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImageBlurInvocation(BaseInvocation):
     """Blurs an image"""
 
     image: ImageField = InputField(description="The image to blur")
@@ -315,7 +300,6 @@ class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -347,13 +331,16 @@ PIL_RESAMPLING_MAP = {
 
 
 @invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
-class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+class ImageResizeInvocation(BaseInvocation):
     """Resizes an image to specific dimensions"""
 
     image: ImageField = InputField(description="The image to resize")
     width: int = InputField(default=512, gt=0, description="The width to resize to (px)")
     height: int = InputField(default=512, gt=0, description="The height to resize to (px)")
     resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
+    metadata: Optional[CoreMetadata] = InputField(
+        default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
+    )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.services.images.get_pil_image(self.image.image_name)
@@ -372,7 +359,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
+            metadata=self.metadata.dict() if self.metadata else None,
             workflow=self.workflow,
         )
@@ -384,7 +371,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
 
 
 @invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
-class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+class ImageScaleInvocation(BaseInvocation):
     """Scales an image by a factor"""
 
     image: ImageField = InputField(description="The image to scale")
@@ -414,7 +401,6 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -426,7 +412,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
 
 
 @invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
-class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImageLerpInvocation(BaseInvocation):
     """Linear interpolation of all pixels of an image"""
 
     image: ImageField = InputField(description="The image to lerp")
@@ -448,7 +434,6 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -460,7 +445,7 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
 
 @invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
-class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImageInverseLerpInvocation(BaseInvocation):
     """Inverse linear interpolation of all pixels of an image"""
 
     image: ImageField = InputField(description="The image to lerp")
@@ -482,7 +467,6 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -494,10 +478,13 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
 
 @invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
-class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+class ImageNSFWBlurInvocation(BaseInvocation):
     """Add blur to NSFW-flagged images"""
 
     image: ImageField = InputField(description="The image to check")
+    metadata: Optional[CoreMetadata] = InputField(
+        default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
+    )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.services.images.get_pil_image(self.image.image_name)
@@ -518,7 +505,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
+            metadata=self.metadata.dict() if self.metadata else None,
             workflow=self.workflow,
         )
@@ -538,11 +525,14 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
 @invocation(
     "img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image", version="1.0.0"
 )
-class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+class ImageWatermarkInvocation(BaseInvocation):
     """Add an invisible watermark to an image"""
 
     image: ImageField = InputField(description="The image to check")
     text: str = InputField(default="InvokeAI", description="Watermark text")
+    metadata: Optional[CoreMetadata] = InputField(
+        default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
+    )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.services.images.get_pil_image(self.image.image_name)
@@ -554,7 +544,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
+            metadata=self.metadata.dict() if self.metadata else None,
             workflow=self.workflow,
         )
@@ -566,7 +556,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
 
 
 @invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
-class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class MaskEdgeInvocation(BaseInvocation):
     """Applies an edge mask to an image"""
 
     image: ImageField = InputField(description="The image to apply the mask to")
@@ -600,7 +590,6 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -614,7 +603,7 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 @invocation(
     "mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image", version="1.0.0"
 )
-class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class MaskCombineInvocation(BaseInvocation):
     """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
 
     mask1: ImageField = InputField(description="The first mask to combine")
@@ -633,7 +622,6 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -645,7 +633,7 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
 
 @invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
-class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ColorCorrectInvocation(BaseInvocation):
     """
     Shifts the colors of a target image to match the reference image, optionally
     using a mask to only color-correct certain regions of the target image.
@@ -744,7 +732,6 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -756,7 +743,7 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
 
 @invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
-class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImageHueAdjustmentInvocation(BaseInvocation):
     """Adjusts the Hue of an image."""
 
     image: ImageField = InputField(description="The image to adjust")
@@ -784,7 +771,6 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             is_intermediate=self.is_intermediate,
             session_id=context.graph_execution_state_id,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -860,7 +846,7 @@ CHANNEL_FORMATS = {
     category="image",
     version="1.0.0",
 )
-class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImageChannelOffsetInvocation(BaseInvocation):
     """Add or subtract a value from a specific color channel of an image."""
 
     image: ImageField = InputField(description="The image to adjust")
@@ -894,7 +880,6 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             is_intermediate=self.is_intermediate,
             session_id=context.graph_execution_state_id,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -931,7 +916,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
     category="image",
     version="1.0.0",
 )
-class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImageChannelMultiplyInvocation(BaseInvocation):
     """Scale a specific color channel of an image."""
 
     image: ImageField = InputField(description="The image to adjust")
@@ -971,7 +956,6 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             is_intermediate=self.is_intermediate,
             session_id=context.graph_execution_state_id,
             workflow=self.workflow,
-            metadata=self.metadata.data if self.metadata else None,
         )
 
         return ImageOutput(
@@ -991,11 +975,16 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
     version="1.0.1",
     use_cache=False,
 )
-class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class SaveImageInvocation(BaseInvocation):
     """Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
 
     image: ImageField = InputField(description=FieldDescriptions.image)
     board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
+    metadata: CoreMetadata = InputField(
+        default=None,
+        description=FieldDescriptions.core_metadata,
+        ui_hidden=True,
+    )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.services.images.get_pil_image(self.image.image_name)
@@ -1008,7 +997,7 @@ class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
+            metadata=self.metadata.dict() if self.metadata else None,
             workflow=self.workflow,
         )
```
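Across these hunks the image-service call flips between `metadata=self.metadata.data if self.metadata else None` and `metadata=self.metadata.dict() if self.metadata else None`. The difference: on one side `metadata` is a wrapper model holding a plain dict in its `data` field; on the other it is the `CoreMetadata` model itself, serialized via pydantic's `.dict()`. A small sketch contrasting the two shapes, using illustrative models rather than the real ones:

```python
from typing import Any

from pydantic import BaseModel

class MetadataDict(BaseModel):
    data: dict[str, Any]  # wrapper: payload lives in .data

class CoreMetadata(BaseModel):
    seed: int             # typed: payload is the model's own fields
    steps: int

wrapped = MetadataDict(data={"seed": 1, "steps": 30})
typed = CoreMetadata(seed=1, steps=30)

# Both styles hand the image service a plain dict in the end:
assert wrapped.data == {"seed": 1, "steps": 30}
assert typed.dict() == {"seed": 1, "steps": 30}
```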
```diff
@@ -5,7 +5,6 @@ from typing import Literal, Optional, get_args
 
 import numpy as np
 from PIL import Image, ImageOps
-from invokeai.app.invocations.metadata import WithMetadata
 
 from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
@@ -14,7 +13,7 @@ from invokeai.backend.image_util.lama import LaMA
 from invokeai.backend.image_util.patchmatch import PatchMatch
 
 from ..models.image import ImageCategory, ResourceOrigin
-from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithWorkflow, invocation
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
 from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
@@ -120,7 +119,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
 
 
 @invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
-class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class InfillColorInvocation(BaseInvocation):
     """Infills transparent areas of an image with a solid color"""
 
     image: ImageField = InputField(description="The image to infill")
@@ -144,7 +143,6 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -156,7 +154,7 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
 
 @invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
-class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class InfillTileInvocation(BaseInvocation):
     """Infills transparent areas of an image with tiles of the image"""
 
     image: ImageField = InputField(description="The image to infill")
@@ -181,7 +179,6 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -195,7 +192,7 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 @invocation(
     "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0"
 )
-class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class InfillPatchMatchInvocation(BaseInvocation):
     """Infills transparent areas of an image using the PatchMatch algorithm"""
 
     image: ImageField = InputField(description="The image to infill")
@@ -235,7 +232,6 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
             workflow=self.workflow,
         )
@@ -247,7 +243,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
 
 @invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
-class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class LaMaInfillInvocation(BaseInvocation):
     """Infills transparent areas of an image using the LaMa model"""
 
     image: ImageField = InputField(description="The image to infill")
@@ -264,8 +260,6 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
-            workflow=self.workflow,
         )
 
         return ImageOutput(
@@ -276,7 +270,7 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
 
 @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint")
-class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class CV2InfillInvocation(BaseInvocation):
     """Infills transparent areas of an image using OpenCV Inpainting"""
 
     image: ImageField = InputField(description="The image to infill")
@@ -293,8 +287,6 @@ class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
-            workflow=self.workflow,
        )
 
        return ImageOutput(
```
```diff
@@ -23,7 +23,7 @@ from pydantic import validator
 from torchvision.transforms.functional import resize as tv_resize
 
 from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.metadata import WithMetadata
+from invokeai.app.invocations.metadata import CoreMetadata
 from invokeai.app.invocations.primitives import (
     DenoiseMaskField,
     DenoiseMaskOutput,
@@ -62,7 +62,6 @@ from .baseinvocation import (
     InvocationContext,
     OutputField,
     UIType,
-    WithWorkflow,
     invocation,
     invocation_output,
 )
@@ -622,7 +621,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 @invocation(
     "l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", version="1.0.0"
 )
-class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+class LatentsToImageInvocation(BaseInvocation):
     """Generates an image from latents."""
 
     latents: LatentsField = InputField(
@@ -635,6 +634,11 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
     )
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
     fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
+    metadata: CoreMetadata = InputField(
+        default=None,
+        description=FieldDescriptions.core_metadata,
+        ui_hidden=True,
+    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -703,7 +707,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.data if self.metadata else None,
+            metadata=self.metadata.dict() if self.metadata else None,
             workflow=self.workflow,
         )
```
```diff
@@ -1,19 +1,18 @@
-from typing import Any, Optional, Union
+from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
 
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    FieldDescriptions,
     InputField,
     InvocationContext,
     OutputField,
-    UIType,
     invocation,
     invocation_output,
 )
-from invokeai.app.invocations.model import LoRAModelField
+from invokeai.app.invocations.controlnet_image_processors import ControlField
+from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
 from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
 
 from ...version import __version__
@@ -26,78 +25,159 @@ class LoRAMetadataField(BaseModelExcludeNull):
     weight: float = Field(description="The weight of the LoRA model")
 
 
+class CoreMetadata(BaseModelExcludeNull):
+    """Core generation metadata for an image generated in InvokeAI."""
+
+    app_version: str = Field(default=__version__, description="The version of InvokeAI used to generate this image")
+    generation_mode: str = Field(
+        description="The generation mode that output this image",
+    )
+    created_by: Optional[str] = Field(description="The name of the creator of the image")
+    positive_prompt: str = Field(description="The positive prompt parameter")
+    negative_prompt: str = Field(description="The negative prompt parameter")
+    width: int = Field(description="The width parameter")
+    height: int = Field(description="The height parameter")
+    seed: int = Field(description="The seed used for noise generation")
+    rand_device: str = Field(description="The device used for random number generation")
+    cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
+    steps: int = Field(description="The number of steps used for inference")
+    scheduler: str = Field(description="The scheduler used for inference")
+    clip_skip: Optional[int] = Field(
+        default=None,
+        description="The number of skipped CLIP layers",
+    )
+    model: MainModelField = Field(description="The main model used for inference")
+    controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
+    loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
+    vae: Optional[VAEModelField] = Field(
+        default=None,
+        description="The VAE used for decoding, if the main model's default was not used",
+    )
+
+    # Latents-to-Latents
+    strength: Optional[float] = Field(
+        default=None,
+        description="The strength used for latents-to-latents",
+    )
+    init_image: Optional[str] = Field(default=None, description="The name of the initial image")
+
+    # SDXL
+    positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter")
+    negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter")
+
+    # SDXL Refiner
+    refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used")
+    refiner_cfg_scale: Optional[float] = Field(
+        default=None,
+        description="The classifier-free guidance scale parameter used for the refiner",
+    )
+    refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
+    refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
+    refiner_positive_aesthetic_score: Optional[float] = Field(
+        default=None, description="The aesthetic score used for the refiner"
+    )
+    refiner_negative_aesthetic_score: Optional[float] = Field(
+        default=None, description="The aesthetic score used for the refiner"
+    )
+    refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
+
+
 class ImageMetadata(BaseModelExcludeNull):
     """An image's generation metadata"""
 
-    metadata: Optional[dict] = Field(default=None, description="The metadata associated with the image")
-    workflow: Optional[dict] = Field(default=None, description="The workflow associated with the image")
+    metadata: Optional[dict] = Field(
+        default=None,
+        description="The image's core metadata, if it was created in the Linear or Canvas UI",
+    )
+    graph: Optional[dict] = Field(default=None, description="The graph that created the image")
 
 
-class MetadataItem(BaseModel):
-    label: str = Field(description=FieldDescriptions.metadata_item_label)
-    value: Any = Field(description=FieldDescriptions.metadata_item_value)
+@invocation_output("metadata_accumulator_output")
+class MetadataAccumulatorOutput(BaseInvocationOutput):
+    """The output of the MetadataAccumulator node"""
+
+    metadata: CoreMetadata = OutputField(description="The core metadata for the image")
 
 
-@invocation_output("metadata_item_output")
-class MetadataItemOutput(BaseInvocationOutput):
-    """Metadata Item Output"""
+@invocation(
+    "metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata", version="1.0.0"
+)
+class MetadataAccumulatorInvocation(BaseInvocation):
+    """Outputs a Core Metadata Object"""
 
-    item: MetadataItem = OutputField(description="Metadata Item")
+    generation_mode: str = InputField(
+        description="The generation mode that output this image",
+    )
+    positive_prompt: str = InputField(description="The positive prompt parameter")
+    negative_prompt: str = InputField(description="The negative prompt parameter")
+    width: int = InputField(description="The width parameter")
+    height: int = InputField(description="The height parameter")
+    seed: int = InputField(description="The seed used for noise generation")
+    rand_device: str = InputField(description="The device used for random number generation")
+    cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
+    steps: int = InputField(description="The number of steps used for inference")
+    scheduler: str = InputField(description="The scheduler used for inference")
+    clip_skip: Optional[int] = Field(
+        default=None,
+        description="The number of skipped CLIP layers",
+    )
+    model: MainModelField = InputField(description="The main model used for inference")
+    controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
+    loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
+    strength: Optional[float] = InputField(
+        default=None,
+        description="The strength used for latents-to-latents",
+    )
+    init_image: Optional[str] = InputField(
+        default=None,
+        description="The name of the initial image",
+    )
+    vae: Optional[VAEModelField] = InputField(
+        default=None,
+        description="The VAE used for decoding, if the main model's default was not used",
+    )
+
+    # SDXL
+    positive_style_prompt: Optional[str] = InputField(
+        default=None,
+        description="The positive style prompt parameter",
+    )
+    negative_style_prompt: Optional[str] = InputField(
+        default=None,
+        description="The negative style prompt parameter",
+    )
 
-@invocation("metadata_item", title="Metadata Item", tags=["metadata"], category="metadata", version="1.0.0")
-class MetadataItemInvocation(BaseInvocation):
-    """Used to create an arbitrary metadata item. Provide "label" and make a connection to "value" to store that data as the value."""
+    # SDXL Refiner
+    refiner_model: Optional[MainModelField] = InputField(
+        default=None,
+        description="The SDXL Refiner model used",
+    )
+    refiner_cfg_scale: Optional[float] = InputField(
+        default=None,
+        description="The classifier-free guidance scale parameter used for the refiner",
+    )
+    refiner_steps: Optional[int] = InputField(
+        default=None,
+        description="The number of steps used for the refiner",
+    )
+    refiner_scheduler: Optional[str] = InputField(
+        default=None,
+        description="The scheduler used for the refiner",
+    )
+    refiner_positive_aesthetic_score: Optional[float] = InputField(
+        default=None,
+        description="The aesthetic score used for the refiner",
+    )
+    refiner_negative_aesthetic_score: Optional[float] = InputField(
+        default=None,
+        description="The aesthetic score used for the refiner",
+    )
+    refiner_start: Optional[float] = InputField(
+        default=None,
+        description="The start value used for refiner denoising",
+    )
 
-    label: str = InputField(description=FieldDescriptions.metadata_item_label)
-    value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any)
+    def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
+        """Collects and outputs a CoreMetadata object"""
 
-    def invoke(self, context: InvocationContext) -> MetadataItemOutput:
-        return MetadataItemOutput(item=MetadataItem(label=self.label, value=self.value))
-
-
-class MetadataDict(BaseModel):
-    """Accepts a single MetadataItem or collection of MetadataItems (use a Collect node)."""
-
-    data: dict[str, Any] = Field(description="Metadata dict")
-
-
-@invocation_output("metadata_dict")
-class MetadataDictOutput(BaseInvocationOutput):
-    metadata_dict: MetadataDict = OutputField(description="Metadata Dict")
-
-
-@invocation("metadata", title="Metadata", tags=["metadata"], category="metadata", version="1.0.0")
-class MetadataInvocation(BaseInvocation):
-    """Takes a MetadataItem or collection of MetadataItems and outputs a MetadataDict."""
-
-    items: Union[list[MetadataItem], MetadataItem] = InputField(description=FieldDescriptions.metadata_item_polymorphic)
-
-    def invoke(self, context: InvocationContext) -> MetadataDictOutput:
-        if isinstance(self.items, MetadataItem):
-            # single metadata item
-            data = {self.items.label: self.items.value}
-        else:
-            # collection of metadata items
-            data = {item.label: item.value for item in self.items}
-
-        data.update({"app_version": __version__})
-        return MetadataDictOutput(metadata_dict=(MetadataDict(data=data)))
-
-
-@invocation("merge_metadata_dict", title="Metadata Merge", tags=["metadata"], category="metadata", version="1.0.0")
-class MergeMetadataDictInvocation(BaseInvocation):
-    """Merged a collection of MetadataDict into a single MetadataDict."""
-
-    collection: list[MetadataDict] = InputField(description=FieldDescriptions.metadata_dict_collection)
-
-    def invoke(self, context: InvocationContext) -> MetadataDictOutput:
-        data = {}
-        for item in self.collection:
-            data.update(item.data)
-
-        return MetadataDictOutput(metadata_dict=MetadataDict(data=data))
-
-
-class WithMetadata(BaseModel):
-    metadata: Optional[MetadataDict] = InputField(default=None, description=FieldDescriptions.metadata)
+        return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
```
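The deleted side of this hunk implements label/value accumulation: `MetadataInvocation` turns one item or a collected list of items into a dict and stamps `app_version`, and `MergeMetadataDictInvocation` folds several dicts together, with later dicts winning on key collisions. That core logic, extracted as plain Python for clarity (tuples stand in for the `MetadataItem` model):

```python
from typing import Any, Union

__version__ = "3.x"  # stand-in for invokeai's version string

def items_to_dict(items: Union[list[tuple[str, Any]], tuple[str, Any]]) -> dict[str, Any]:
    """Mirrors MetadataInvocation.invoke: one item or many -> one dict."""
    if isinstance(items, tuple):
        data = {items[0]: items[1]}                      # single metadata item
    else:
        data = {label: value for label, value in items}  # collected items
    data.update({"app_version": __version__})
    return data

def merge_dicts(collection: list[dict[str, Any]]) -> dict[str, Any]:
    """Mirrors MergeMetadataDictInvocation.invoke: later entries override earlier ones."""
    data: dict[str, Any] = {}
    for item in collection:
        data.update(item)
    return data

d1 = items_to_dict(("seed", 123))
d2 = items_to_dict([("steps", 30), ("cfg_scale", 7.5)])
print(merge_dicts([d1, d2]))
```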
@@ -12,7 +12,7 @@ from diffusers.image_processor import VaeImageProcessor
from pydantic import BaseModel, Field, validator
from tqdm import tqdm

from invokeai.app.invocations.metadata import WithMetadata
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import BaseModelType, ModelType, SubModelType
@@ -28,7 +28,6 @@ from .baseinvocation import (
    Input,
    InputField,
    InvocationContext,
    WithWorkflow,
    OutputField,
    UIComponent,
    UIType,
@@ -322,7 +321,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
    category="image",
    version="1.0.0",
)
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
class ONNXLatentsToImageInvocation(BaseInvocation):
    """Generates an image from latents."""

    latents: LatentsField = InputField(
@@ -333,6 +332,11 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )
    metadata: Optional[CoreMetadata] = InputField(
        default=None,
        description=FieldDescriptions.core_metadata,
        ui_hidden=True,
    )
    # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")

    def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -371,7 +375,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            metadata=self.metadata.data if self.metadata else None,
            metadata=self.metadata.dict() if self.metadata else None,
            workflow=self.workflow,
        )

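Note: the swap from `self.metadata.data` to `self.metadata.dict()` reflects the return to a typed pydantic model — `CoreMetadata` serializes itself, whereas the removed `MetadataDict` carried an untyped payload on a `.data` attribute. A minimal sketch with a stand-in pydantic v1 model (field names illustrative, not the real CoreMetadata):

# Sketch only: a stand-in pydantic v1 model, not InvokeAI's actual CoreMetadata.
from typing import Optional
from pydantic import BaseModel

class FakeCoreMetadata(BaseModel):
    seed: Optional[int] = None
    steps: Optional[int] = None

meta = FakeCoreMetadata(seed=123, steps=30)
# .dict() serializes the whole typed model; the removed MetadataDict variant
# instead exposed an untyped dict on a `.data` attribute.
assert meta.dict() == {"seed": 123, "steps": 30}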
@@ -251,9 +251,7 @@ class ImageCollectionOutput(BaseInvocationOutput):


@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
class ImageInvocation(
    BaseInvocation,
):
class ImageInvocation(BaseInvocation):
    """An image primitive value"""

    image: ImageField = InputField(description="The image to load")

@@ -7,12 +7,11 @@ import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer
from invokeai.app.invocations.metadata import WithMetadata

from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ResourceOrigin

from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithWorkflow, invocation
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation

# TODO: Populate this from disk?
# TODO: Use model manager to load?
@@ -25,7 +24,7 @@ ESRGAN_MODELS = Literal[


@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.0")
class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class ESRGANInvocation(BaseInvocation):
    """Upscales an image using RealESRGAN."""

    image: ImageField = InputField(description="The input image")
@@ -107,7 +106,6 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            metadata=self.metadata.data if self.metadata else None,
            workflow=self.workflow,
        )

@@ -117,6 +117,10 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
    if from_type is int and to_type is float:
        return True

    # allow int|float -> str, pydantic will cast for us
    if (from_type is int or from_type is float) and to_type is str:
        return True

    # if not issubclass(from_type, to_type):
    if not is_union_subtype(from_type, to_type):
        return False
@@ -421,14 +425,6 @@ class Graph(BaseModel):

        return True

    def _is_destination_field_Any(self, edge: Edge) -> bool:
        """Checks if the destination field for an edge is of type typing.Any"""
        return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == Any

    def _is_destination_field_list_of_Any(self, edge: Edge) -> bool:
        """Checks if the destination field for an edge is of type typing.Any"""
        return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]

    def _validate_edge(self, edge: Edge):
        """Validates that a new edge doesn't create a cycle in the graph"""

@@ -481,19 +477,8 @@ class Graph(BaseModel):
                    f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
                )

        # Validate that we are not connecting collector to iterator (currently unsupported)
        if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation):
            raise InvalidEdgeError(
                f"Cannot connect collector to iterator: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
            )

        # Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
        if (
            isinstance(from_node, CollectInvocation)
            and edge.source.field == "collection"
            and not self._is_destination_field_list_of_Any(edge)
            and not self._is_destination_field_Any(edge)
        ):
        # Validate if collector output type matches input type (if this edge results in both being set)
        if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
            if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
                raise InvalidEdgeError(
                    f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
@@ -726,15 +711,16 @@ class Graph(BaseModel):
        # Get the input root type
        input_root_type = next(t[0] for t in type_degrees if t[1] == 0)  # type: ignore

        # Verify that all outputs are lists
        # if not all((get_origin(f) == list for f in output_fields)):
        #     return False

        # Verify that all outputs are lists
        if not all(is_list_or_contains_list(f) for f in output_fields):
            return False

        # Verify that all outputs match the input type (are a base class or the same class)
        if not all(
            is_union_subtype(input_root_type, get_args(f)[0]) or issubclass(input_root_type, get_args(f)[0])
            for f in output_fields
        ):
        if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)):
            return False

        return True

@@ -59,7 +59,7 @@ class ImageFileStorageBase(ABC):
        self,
        image: PILImageType,
        image_name: str,
        metadata: Optional[Union[str, dict]] = None,
        metadata: Optional[dict] = None,
        workflow: Optional[str] = None,
        thumbnail_size: int = 256,
    ) -> None:
@@ -109,7 +109,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
        self,
        image: PILImageType,
        image_name: str,
        metadata: Optional[Union[str, dict]] = None,
        metadata: Optional[dict] = None,
        workflow: Optional[str] = None,
        thumbnail_size: int = 256,
    ) -> None:
@@ -119,10 +119,20 @@ class DiskImageFileStorage(ImageFileStorageBase):

            pnginfo = PngImagePlugin.PngInfo()

            if metadata is not None:
                pnginfo.add_text("invokeai_metadata", json.dumps(metadata) if type(metadata) is dict else metadata)
            if workflow is not None:
                pnginfo.add_text("invokeai_workflow", workflow)
            if metadata is not None or workflow is not None:
                if metadata is not None:
                    pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
                if workflow is not None:
                    pnginfo.add_text("invokeai_workflow", workflow)
            else:
                # For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
                # TODO: retain non-invokeai metadata on save...
                original_metadata = image.info.get("invokeai_metadata", None)
                if original_metadata is not None:
                    pnginfo.add_text("invokeai_metadata", original_metadata)
                original_workflow = image.info.get("invokeai_workflow", None)
                if original_workflow is not None:
                    pnginfo.add_text("invokeai_workflow", original_workflow)

            image.save(image_path, "PNG", pnginfo=pnginfo)

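Note: the hunk above is the whole PNG metadata story — values are written as tEXt chunks via PngInfo and come back on Image.open() through image.info, which is why re-saving an upload without copying image.info would silently drop its metadata. A minimal round-trip sketch (file path illustrative):

# Round-trip sketch of PNG tEXt metadata with Pillow (not part of the diff).
import json
from PIL import Image, PngImagePlugin

img = Image.new("RGB", (8, 8))
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai_metadata", json.dumps({"seed": 123}))
img.save("/tmp/example.png", "PNG", pnginfo=pnginfo)

# Reading it back: Pillow exposes tEXt chunks on .info
reloaded = Image.open("/tmp/example.png")
assert json.loads(reloaded.info["invokeai_metadata"]) == {"seed": 123}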
@@ -3,12 +3,11 @@ import sqlite3
import threading
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Generic, Optional, TypeVar, Union, cast
from typing import Generic, Optional, TypeVar, cast

from pydantic import BaseModel, Field
from pydantic.generics import GenericModel

from invokeai.app.invocations.metadata import ImageMetadata
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.services.models.image_record import ImageRecord, ImageRecordChanges, deserialize_image_record

@@ -82,7 +81,7 @@ class ImageRecordStorageBase(ABC):
        pass

    @abstractmethod
    def get_metadata(self, image_name: str) -> ImageMetadata:
    def get_metadata(self, image_name: str) -> Optional[dict]:
        """Gets an image's metadata'."""
        pass

@@ -135,8 +134,7 @@ class ImageRecordStorageBase(ABC):
        height: int,
        session_id: Optional[str],
        node_id: Optional[str],
        metadata: Optional[Union[str, dict]],
        workflow: Optional[str],
        metadata: Optional[dict],
        is_intermediate: bool = False,
        starred: bool = False,
    ) -> datetime:
@@ -206,13 +204,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
            """
        )

        if "workflow" not in columns:
            self._cursor.execute(
                """--sql
                ALTER TABLE images ADD COLUMN workflow TEXT;
                """
            )

        # Create the `images` table indices.
        self._cursor.execute(
            """--sql
@@ -278,31 +269,22 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):

        return deserialize_image_record(dict(result))

    def get_metadata(self, image_name: str) -> ImageMetadata:
    def get_metadata(self, image_name: str) -> Optional[dict]:
        try:
            self._lock.acquire()

            self._cursor.execute(
                """--sql
                SELECT metadata, workflow FROM images
                SELECT images.metadata FROM images
                WHERE image_name = ?;
                """,
                (image_name,),
            )

            result = cast(Optional[sqlite3.Row], self._cursor.fetchone())

            if not result:
                return ImageMetadata()

            as_dict = dict(result)
            metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
            workflow_raw = cast(Optional[str], as_dict.get("workflow", None))

            return ImageMetadata(
                metadata=json.loads(metadata_raw) if metadata_raw is not None else None,
                workflow=json.loads(workflow_raw) if workflow_raw is not None else None,
            )
            if not result or not result[0]:
                return None
            return json.loads(result[0])
        except sqlite3.Error as e:
            self._conn.rollback()
            raise ImageRecordNotFoundException from e
@@ -537,15 +519,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        width: int,
        height: int,
        node_id: Optional[str],
        metadata: Optional[Union[str, dict]],
        workflow: Optional[str],
        metadata: Optional[dict],
        is_intermediate: bool = False,
        starred: bool = False,
    ) -> datetime:
        try:
            metadata_json: Optional[str] = None
            if metadata is not None:
                metadata_json = metadata if type(metadata) is str else json.dumps(metadata)
            metadata_json = None if metadata is None else json.dumps(metadata)
            self._lock.acquire()
            self._cursor.execute(
                """--sql
@@ -558,11 +537,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
                    node_id,
                    session_id,
                    metadata,
                    workflow,
                    is_intermediate,
                    starred
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
                """,
                (
                    image_name,
@@ -573,7 +551,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
                    node_id,
                    session_id,
                    metadata_json,
                    workflow,
                    is_intermediate,
                    starred,
                ),

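Note: because metadata is now always a dict (never a pre-serialized string), writing reduces to one json.dumps into a TEXT column and reading to one json.loads. A self-contained sketch of that storage pattern (table and column names illustrative, not the real schema):

# Sketch of storing a JSON metadata dict in a sqlite TEXT column (not part of the diff).
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE images (image_name TEXT PRIMARY KEY, metadata TEXT)")

metadata = {"seed": 123, "steps": 30}
metadata_json = None if metadata is None else json.dumps(metadata)
conn.execute("INSERT INTO images VALUES (?, ?)", ("img.png", metadata_json))

row = conn.execute("SELECT metadata FROM images WHERE image_name = ?", ("img.png",)).fetchone()
restored = json.loads(row[0]) if row and row[0] else None
assert restored == metadata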
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional

from PIL.Image import Image as PILImageType

@@ -29,6 +29,7 @@ from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.models.image_record import ImageDTO, ImageRecord, ImageRecordChanges, image_record_to_dto
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session

if TYPE_CHECKING:
    from invokeai.app.services.graph import GraphExecutionState
@@ -70,7 +71,7 @@ class ImageServiceABC(ABC):
        session_id: Optional[str] = None,
        board_id: Optional[str] = None,
        is_intermediate: bool = False,
        metadata: Optional[Union[str, dict]] = None,
        metadata: Optional[dict] = None,
        workflow: Optional[str] = None,
    ) -> ImageDTO:
        """Creates an image, storing the file and its metadata."""
@@ -195,7 +196,7 @@ class ImageService(ImageServiceABC):
        session_id: Optional[str] = None,
        board_id: Optional[str] = None,
        is_intermediate: bool = False,
        metadata: Optional[Union[str, dict]] = None,
        metadata: Optional[dict] = None,
        workflow: Optional[str] = None,
    ) -> ImageDTO:
        if image_origin not in ResourceOrigin:
@@ -233,7 +234,6 @@ class ImageService(ImageServiceABC):
                # Nullable fields
                node_id=node_id,
                metadata=metadata,
                workflow=workflow,
                session_id=session_id,
            )
            if board_id is not None:
@@ -311,7 +311,23 @@ class ImageService(ImageServiceABC):

    def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
        try:
            return self._services.image_records.get_metadata(image_name)
            image_record = self._services.image_records.get(image_name)
            metadata = self._services.image_records.get_metadata(image_name)

            if not image_record.session_id:
                return ImageMetadata(metadata=metadata)

            session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
            graph = None

            if session_raw:
                try:
                    graph = get_metadata_graph_from_raw_session(session_raw)
                except Exception as e:
                    self._services.logger.warn(f"Failed to parse session graph: {e}")
                    graph = None

            return ImageMetadata(graph=graph, metadata=metadata)
        except ImageRecordNotFoundException:
            self._services.logger.error("Image record not found")
            raise

@@ -1,4 +1,7 @@
from queue import Queue
from collections import OrderedDict
from dataclasses import dataclass, field
from threading import Lock
from time import time
from typing import Optional, Union

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
@@ -7,22 +10,28 @@ from invokeai.app.services.invocation_cache.invocation_cache_common import Invoc
from invokeai.app.services.invoker import Invoker


@dataclass(order=True)
class CachedItem:
    invocation_output: BaseInvocationOutput = field(compare=False)
    invocation_output_json: str = field(compare=False)


class MemoryInvocationCache(InvocationCacheBase):
    _cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
    _cache: OrderedDict[Union[int, str], CachedItem]
    _max_cache_size: int
    _disabled: bool
    _hits: int
    _misses: int
    _cache_ids: Queue
    _invoker: Invoker
    _lock: Lock

    def __init__(self, max_cache_size: int = 0) -> None:
        self._cache = dict()
        self._cache = OrderedDict()
        self._max_cache_size = max_cache_size
        self._disabled = False
        self._hits = 0
        self._misses = 0
        self._cache_ids = Queue()
        self._lock = Lock()

    def start(self, invoker: Invoker) -> None:
        self._invoker = invoker
@@ -32,80 +41,87 @@ class MemoryInvocationCache(InvocationCacheBase):
        self._invoker.services.latents.on_deleted(self._delete_by_match)

    def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
        if self._max_cache_size == 0 or self._disabled:
            return

        item = self._cache.get(key, None)
        if item is not None:
            self._hits += 1
            return item[0]
        self._misses += 1
        with self._lock:
            if self._max_cache_size == 0 or self._disabled:
                return None
            item = self._cache.get(key, None)
            if item is not None:
                self._hits += 1
                self._cache.move_to_end(key)
                return item.invocation_output
            self._misses += 1
            return None

    def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
        if self._max_cache_size == 0 or self._disabled:
            return
        with self._lock:
            if self._max_cache_size == 0 or self._disabled or key in self._cache:
                return
            # If the cache is full, we need to remove the least used
            number_to_delete = len(self._cache) + 1 - self._max_cache_size
            self._delete_oldest_access(number_to_delete)
            self._cache[key] = CachedItem(invocation_output, invocation_output.json())

        if key not in self._cache:
            self._cache[key] = (invocation_output, invocation_output.json())
            self._cache_ids.put(key)
            if self._cache_ids.qsize() > self._max_cache_size:
                try:
                    self._cache.pop(self._cache_ids.get())
                except KeyError:
                    # this means the cache_ids are somehow out of sync w/ the cache
                    pass
    def _delete_oldest_access(self, number_to_delete: int) -> None:
        number_to_delete = min(number_to_delete, len(self._cache))
        for _ in range(number_to_delete):
            self._cache.popitem(last=False)

    def delete(self, key: Union[int, str]) -> None:
    def _delete(self, key: Union[int, str]) -> None:
        if self._max_cache_size == 0:
            return

        if key in self._cache:
            del self._cache[key]

    def delete(self, key: Union[int, str]) -> None:
        with self._lock:
            return self._delete(key)

    def clear(self, *args, **kwargs) -> None:
        if self._max_cache_size == 0:
            return
        with self._lock:
            if self._max_cache_size == 0:
                return
            self._cache.clear()
            self._misses = 0
            self._hits = 0

        self._cache.clear()
        self._cache_ids = Queue()
        self._misses = 0
        self._hits = 0

    def create_key(self, invocation: BaseInvocation) -> int:
    @staticmethod
    def create_key(invocation: BaseInvocation) -> int:
        return hash(invocation.json(exclude={"id"}))

    def disable(self) -> None:
        if self._max_cache_size == 0:
            return
        self._disabled = True
        with self._lock:
            if self._max_cache_size == 0:
                return
            self._disabled = True

    def enable(self) -> None:
        if self._max_cache_size == 0:
            return
        self._disabled = False
        with self._lock:
            if self._max_cache_size == 0:
                return
            self._disabled = False

    def get_status(self) -> InvocationCacheStatus:
        return InvocationCacheStatus(
            hits=self._hits,
            misses=self._misses,
            enabled=not self._disabled and self._max_cache_size > 0,
            size=len(self._cache),
            max_size=self._max_cache_size,
        )
        with self._lock:
            return InvocationCacheStatus(
                hits=self._hits,
                misses=self._misses,
                enabled=not self._disabled and self._max_cache_size > 0,
                size=len(self._cache),
                max_size=self._max_cache_size,
            )

    def _delete_by_match(self, to_match: str) -> None:
        if self._max_cache_size == 0:
            return

        keys_to_delete = set()
        for key, value_tuple in self._cache.items():
            if to_match in value_tuple[1]:
                keys_to_delete.add(key)

        if not keys_to_delete:
            return

        for key in keys_to_delete:
            self.delete(key)

        self._invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}")
        with self._lock:
            if self._max_cache_size == 0:
                return
            keys_to_delete = set()
            for key, cached_item in self._cache.items():
                if to_match in cached_item.invocation_output_json:
                    keys_to_delete.add(key)
            if not keys_to_delete:
                return
            for key in keys_to_delete:
                self._delete(key)
            self._invoker.services.logger.debug(
                f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}"
            )

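Note: the rewrite turns the cache into a true LRU — OrderedDict preserves insertion order, move_to_end() on every hit marks an entry as recently used, and popitem(last=False) evicts from the cold end, all under one Lock. A minimal standalone sketch of that mechanism (not the real cache class):

# Minimal LRU sketch mirroring the OrderedDict technique above.
from collections import OrderedDict
from threading import Lock

class TinyLRU:
    def __init__(self, max_size: int) -> None:
        self._items = OrderedDict()
        self._max_size = max_size
        self._lock = Lock()

    def get(self, key):
        with self._lock:
            if key in self._items:
                self._items.move_to_end(key)  # mark as most recently used
                return self._items[key]
            return None

    def put(self, key, value) -> None:
        with self._lock:
            if key in self._items:
                return
            while len(self._items) >= self._max_size:
                self._items.popitem(last=False)  # evict least recently used
            self._items[key] = value

lru = TinyLRU(max_size=2)
lru.put("a", 1); lru.put("b", 2); lru.get("a"); lru.put("c", 3)
assert lru.get("b") is None and lru.get("a") == 1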
@@ -47,20 +47,27 @@ class DefaultSessionProcessor(SessionProcessorBase):
    async def _on_queue_event(self, event: FastAPIEvent) -> None:
        event_name = event[1]["event"]

        match event_name:
            case "graph_execution_state_complete" | "invocation_error" | "session_retrieval_error" | "invocation_retrieval_error":
                self.__queue_item = None
                self._poll_now()
            case "session_canceled" if self.__queue_item is not None and self.__queue_item.session_id == event[1][
                "data"
            ]["graph_execution_state_id"]:
                self.__queue_item = None
                self._poll_now()
            case "batch_enqueued":
                self._poll_now()
            case "queue_cleared":
                self.__queue_item = None
                self._poll_now()
        # This was a match statement, but match is not supported on python 3.9
        if event_name in [
            "graph_execution_state_complete",
            "invocation_error",
            "session_retrieval_error",
            "invocation_retrieval_error",
        ]:
            self.__queue_item = None
            self._poll_now()
        elif (
            event_name == "session_canceled"
            and self.__queue_item is not None
            and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"]
        ):
            self.__queue_item = None
            self._poll_now()
        elif event_name == "batch_enqueued":
            self._poll_now()
        elif event_name == "queue_cleared":
            self.__queue_item = None
            self._poll_now()

    def resume(self) -> SessionProcessorStatus:
        if not self.__resume_event.is_set():

@@ -59,13 +59,14 @@ class SqliteSessionQueue(SessionQueueBase):

    async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
        event_name = event[1]["event"]
        match event_name:
            case "graph_execution_state_complete":
                await self._handle_complete_event(event)
            case "invocation_error" | "session_retrieval_error" | "invocation_retrieval_error":
                await self._handle_error_event(event)
            case "session_canceled":
                await self._handle_cancel_event(event)

        # This was a match statement, but match is not supported on python 3.9
        if event_name == "graph_execution_state_complete":
            await self._handle_complete_event(event)
        elif event_name in ["invocation_error", "session_retrieval_error", "invocation_retrieval_error"]:
            await self._handle_error_event(event)
        elif event_name == "session_canceled":
            await self._handle_cancel_event(event)
        return event

    async def _handle_complete_event(self, event: FastAPIEvent) -> None:

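Note: both event handlers apply the same Python 3.9 downgrade recipe — a structural `match` on a string becomes an if/elif chain, with or-patterns (`"a" | "b"`) turned into membership tests and `case ... if guard` folded into the `elif` condition. A tiny illustration of the mapping:

# 3.9-compatible form of a match statement over event names (sketch only).
def dispatch(event_name: str) -> str:
    # Python 3.10+ equivalent:
    #   match event_name:
    #       case "invocation_error" | "session_retrieval_error": return "error"
    #       case "session_canceled": return "canceled"
    if event_name in ["invocation_error", "session_retrieval_error"]:
        return "error"
    elif event_name == "session_canceled":
        return "canceled"
    return "other"

assert dispatch("session_retrieval_error") == "error"
assert dispatch("batch_enqueued") == "other"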
@@ -93,7 +93,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# or renaming it and then running invokeai-configure again.
"""

logger = InvokeAILogger.getLogger()
logger = InvokeAILogger.get_logger()


class DummyWidgetValue(Enum):
@@ -894,7 +894,7 @@ def main():
    if opt.full_precision:
        invoke_args.extend(["--precision", "float32"])
    config.parse_args(invoke_args)
    logger = InvokeAILogger().getLogger(config=config)
    logger = InvokeAILogger().get_logger(config=config)

    errors = set()

@@ -30,7 +30,7 @@ warnings.filterwarnings("ignore")

# --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.getLogger(name="InvokeAI")
logger = InvokeAILogger.get_logger(name="InvokeAI")

# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
@@ -47,8 +47,14 @@ Config_preamble = """

LEGACY_CONFIGS = {
    BaseModelType.StableDiffusion1: {
        ModelVariantType.Normal: "v1-inference.yaml",
        ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
        ModelVariantType.Normal: {
            SchedulerPredictionType.Epsilon: "v1-inference.yaml",
            SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
        },
        ModelVariantType.Inpaint: {
            SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml",
            SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml",
        },
    },
    BaseModelType.StableDiffusion2: {
        ModelVariantType.Normal: {
@@ -69,14 +75,6 @@ LEGACY_CONFIGS = {
}


@dataclass
class ModelInstallList:
    """Class for listing models to be installed/removed"""

    install_models: List[str] = field(default_factory=list)
    remove_models: List[str] = field(default_factory=list)


@dataclass
class InstallSelections:
    install_models: List[str] = field(default_factory=list)
@@ -94,6 +92,7 @@ class ModelLoadInfo:
    installed: bool = False
    recommended: bool = False
    default: bool = False
    requires: Optional[List[str]] = field(default_factory=list)


class ModelInstall(object):
@@ -131,8 +130,6 @@ class ModelInstall(object):

        # supplement with entries in models.yaml
        installed_models = [x for x in self.mgr.list_models()]
        # suppresses autoloaded models
        # installed_models = [x for x in self.mgr.list_models() if not self._is_autoloaded(x)]

        for md in installed_models:
            base = md["base_model"]
@@ -164,9 +161,12 @@ class ModelInstall(object):

    def list_models(self, model_type):
        installed = self.mgr.list_models(model_type=model_type)
        print()
        print(f"Installed models of type `{model_type}`:")
        print(f"{'Model Key':50} Model Path")
        for i in installed:
            print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
            print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}")
        print()

    # logic here a little reversed to maintain backward compatibility
    def starter_models(self, all_models: bool = False) -> Set[str]:
@@ -204,6 +204,8 @@ class ModelInstall(object):
            job += 1

        # add requested models
        self._remove_installed(selections.install_models)
        self._add_required_models(selections.install_models)
        for path in selections.install_models:
            logger.info(f"Installing {path} [{job}/{jobs}]")
            try:
@@ -263,6 +265,26 @@ class ModelInstall(object):

        return models_installed

    def _remove_installed(self, model_list: List[str]):
        all_models = self.all_models()
        for path in model_list:
            key = self.reverse_paths.get(path)
            if key and all_models[key].installed:
                logger.warning(f"{path} already installed. Skipping.")
                model_list.remove(path)

    def _add_required_models(self, model_list: List[str]):
        additional_models = []
        all_models = self.all_models()
        for path in model_list:
            if not (key := self.reverse_paths.get(path)):
                continue
            for requirement in all_models[key].requires:
                requirement_key = self.reverse_paths.get(requirement)
                if not all_models[requirement_key].installed:
                    additional_models.append(requirement)
        model_list.extend(additional_models)

    # install a model from a local path. The optional info parameter is there to prevent
    # the model from being probed twice in the event that it has already been probed.
    def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
@@ -286,7 +308,7 @@ class ModelInstall(object):
        location = download_with_resume(url, Path(staging))
        if not location:
            logger.error(f"Unable to download {url}. Skipping.")
        info = ModelProbe().heuristic_probe(location)
        info = ModelProbe().heuristic_probe(location, self.prediction_helper)
        dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
        dest.parent.mkdir(parents=True, exist_ok=True)
        models_path = shutil.move(location, dest)
@@ -393,7 +415,7 @@ class ModelInstall(object):
        possible_conf = path.with_suffix(".yaml")
        if possible_conf.exists():
            legacy_conf = str(self.relative_to_root(possible_conf))
        elif info.base_type == BaseModelType.StableDiffusion2:
        elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
            legacy_conf = Path(
                self.config.legacy_conf_dir,
                LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
@@ -492,7 +514,7 @@ def yes_or_no(prompt: str, default_yes=True):

# ---------------------------------------------
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
    logger = InvokeAILogger.getLogger("InvokeAI")
    logger = InvokeAILogger.get_logger("InvokeAI")
    logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())

    model = model_class.from_pretrained(

@@ -74,7 +74,7 @@ if is_accelerate_available():
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

logger = InvokeAILogger.getLogger(__name__)
logger = InvokeAILogger.get_logger(__name__)
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert"


@@ -1279,12 +1279,12 @@ def download_from_original_stable_diffusion_ckpt(
        extract_ema = original_config["model"]["params"]["use_ema"]

    if (
        model_version == BaseModelType.StableDiffusion2
        model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1]
        and original_config["model"]["params"].get("parameterization") == "v"
    ):
        prediction_type = "v_prediction"
        upcast_attention = True
        image_size = 768
        image_size = 768 if model_version == BaseModelType.StableDiffusion2 else 512
    else:
        prediction_type = "epsilon"
        upcast_attention = False

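Note: the net effect of this hunk is that v-parameterized checkpoints are now recognized for SD 1.x as well as SD 2.x, with the native resolution kept at 512 for SD 1.x. A sketch of the resulting decision logic in isolation ("sd-1"/"sd-2" are stand-in strings, not InvokeAI's real BaseModelType enum members):

# Decision-table sketch for the converter hunk above (stand-in values).
def pick_conversion_params(model_version: str, parameterization: str):
    if model_version in ["sd-2", "sd-1"] and parameterization == "v":
        # v-prediction now matches SD 1.x too, but only SD 2.x is 768px-native
        return "v_prediction", (768 if model_version == "sd-2" else 512)
    return "epsilon", None  # image size is decided elsewhere in the real converter

assert pick_conversion_params("sd-1", "v") == ("v_prediction", 512)
assert pick_conversion_params("sd-2", "v") == ("v_prediction", 768)
assert pick_conversion_params("sd-1", "eps")[0] == "epsilon"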
@@ -90,8 +90,7 @@ class ModelProbe(object):
        to place it somewhere in the models directory hierarchy. If the model is
        already loaded into memory, you may provide it as model in order to avoid
        opening it a second time. The prediction_type_helper callable is a function that receives
        the path to the model and returns the BaseModelType. It is called to distinguish
        between V2-Base and V2-768 SD models.
        the path to the model and returns the SchedulerPredictionType.
        """
        if model_path:
            format_type = "diffusers" if model_path.is_dir() else "checkpoint"
@@ -305,25 +304,36 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
        else:
            raise InvalidModelException("Cannot determine base type")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
    def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
        """Return model prediction type."""
        # if there is a .yaml associated with this checkpoint, then we do not need
        # to probe for the prediction type as it will be ignored.
        if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists():
            return None

        type = self.get_base_type()
        if type == BaseModelType.StableDiffusion1:
            return SchedulerPredictionType.Epsilon
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            if "global_step" in checkpoint:
                if checkpoint["global_step"] == 220000:
                    return SchedulerPredictionType.Epsilon
                elif checkpoint["global_step"] == 110000:
                    return SchedulerPredictionType.VPrediction
            if (
                self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists()
            ):  # if a .yaml config file exists, then this step not needed
                return self.helper(self.checkpoint_path)
            else:
                return None
        if type == BaseModelType.StableDiffusion2:
            checkpoint = self.checkpoint
            state_dict = self.checkpoint.get("state_dict") or checkpoint
            key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
            if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
                if "global_step" in checkpoint:
                    if checkpoint["global_step"] == 220000:
                        return SchedulerPredictionType.Epsilon
                    elif checkpoint["global_step"] == 110000:
                        return SchedulerPredictionType.VPrediction
            if self.helper and self.checkpoint_path:
                if helper_guess := self.helper(self.checkpoint_path):
                    return helper_guess
            return SchedulerPredictionType.VPrediction  # a guess for sd2 ckpts

        elif type == BaseModelType.StableDiffusion1:
            if self.helper and self.checkpoint_path:
                if helper_guess := self.helper(self.checkpoint_path):
                    return helper_guess
            return SchedulerPredictionType.Epsilon  # a reasonable guess for sd1 ckpts
        else:
            return None


class VaeCheckpointProbe(CheckpointProbeBase):

@@ -71,7 +71,13 @@ class ModelSearch(ABC):
                if any(
                    [
                        (path / x).exists()
                        for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
                        for x in {
                            "config.json",
                            "model_index.json",
                            "learned_embeds.bin",
                            "pytorch_lora_weights.bin",
                            "image_encoder.txt",
                        }
                    ]
                ):
                    try:

@@ -24,7 +24,7 @@ from invokeai.backend.util.logging import InvokeAILogger
# Modified ControlNetModel with encoder_attention_mask argument added


logger = InvokeAILogger.getLogger(__name__)
logger = InvokeAILogger.get_logger(__name__)


class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):

@@ -1,7 +1,6 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team

"""
invokeai.backend.util.logging
"""invokeai.backend.util.logging

Logging class for InvokeAI that produces console messages

@@ -9,9 +8,9 @@ Usage:

from invokeai.backend.util.logging import InvokeAILogger

logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
logger = InvokeAILogger.get_logger(name='InvokeAI') // Initialization
(or)
logger = InvokeAILogger.getLogger(__name__) // To use the filename
logger = InvokeAILogger.get_logger(__name__) // To use the filename
logger.configure()

logger.critical('this is critical') // Critical Message
@@ -34,13 +33,13 @@ IAILogger.debug('this is a debugging message')
## Configuration

The default configuration will print to stderr on the console. To add
additional logging handlers, call getLogger with an initialized InvokeAIAppConfig
additional logging handlers, call get_logger with an initialized InvokeAIAppConfig
object:


config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.getLogger(config=config)
logger = InvokeAILogger.get_logger(config=config)

### Three command-line options control logging:

@@ -173,6 +172,7 @@ InvokeAI:
  log_level: info
  log_format: color
```

"""

import logging.handlers
@@ -193,39 +193,35 @@ except ImportError:

# module level functions
def debug(msg, *args, **kwargs):
    InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
    InvokeAILogger.get_logger().debug(msg, *args, **kwargs)


def info(msg, *args, **kwargs):
    InvokeAILogger.getLogger().info(msg, *args, **kwargs)
    InvokeAILogger.get_logger().info(msg, *args, **kwargs)


def warning(msg, *args, **kwargs):
    InvokeAILogger.getLogger().warning(msg, *args, **kwargs)
    InvokeAILogger.get_logger().warning(msg, *args, **kwargs)


def error(msg, *args, **kwargs):
    InvokeAILogger.getLogger().error(msg, *args, **kwargs)
    InvokeAILogger.get_logger().error(msg, *args, **kwargs)


def critical(msg, *args, **kwargs):
    InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
    InvokeAILogger.get_logger().critical(msg, *args, **kwargs)


def log(level, msg, *args, **kwargs):
    InvokeAILogger.getLogger().log(level, msg, *args, **kwargs)
    InvokeAILogger.get_logger().log(level, msg, *args, **kwargs)


def disable(level=logging.CRITICAL):
    InvokeAILogger.getLogger().disable(level)
    InvokeAILogger.get_logger().disable(level)


def basicConfig(**kwargs):
    InvokeAILogger.getLogger().basicConfig(**kwargs)


def getLogger(name: str = None) -> logging.Logger:
    return InvokeAILogger.getLogger(name)
    InvokeAILogger.get_logger().basicConfig(**kwargs)


_FACILITY_MAP = (
@@ -351,7 +347,7 @@ class InvokeAILogger(object):
    loggers = dict()

    @classmethod
    def getLogger(
    def get_logger(
        cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
    ) -> logging.Logger:
        if name in cls.loggers:
@@ -360,13 +356,13 @@ class InvokeAILogger(object):
        else:
            logger = logging.getLogger(name)
            logger.setLevel(config.log_level.upper())  # yes, strings work here
            for ch in cls.getLoggers(config):
            for ch in cls.get_loggers(config):
                logger.addHandler(ch)
            cls.loggers[name] = logger
        return cls.loggers[name]

    @classmethod
    def getLoggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
    def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
        handler_strs = config.log_handlers
        handlers = list()
        for handler in handler_strs:

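Note: callers migrate mechanically from getLogger to get_logger, and the class body above shows why repeated calls are cheap — loggers are memoized in cls.loggers, so the same name hands back the same configured logging.Logger with handlers attached only once. A quick illustration of that memoization pattern (stand-in class, not the real InvokeAILogger):

# Sketch of the name -> logger memoization used by get_logger() above.
import logging

class TinyLoggerFactory:
    loggers: dict = {}

    @classmethod
    def get_logger(cls, name: str = "InvokeAI") -> logging.Logger:
        if name not in cls.loggers:
            logger = logging.getLogger(name)
            logger.addHandler(logging.StreamHandler())  # handlers attached once
            cls.loggers[name] = logger
        return cls.loggers[name]

assert TinyLoggerFactory.get_logger("x") is TinyLoggerFactory.get_logger("x")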
@@ -103,3 +103,35 @@ sd-1/lora/LowRA:
  recommended: True
sd-1/lora/Ink scenery:
  path: https://civitai.com/api/download/models/83390
sd-1/ip_adapter/ip_adapter_sd15:
  repo_id: InvokeAI/ip_adapter_sd15
  recommended: True
  requires:
    - InvokeAI/ip_adapter_sd_image_encoder
  description: IP-Adapter for SD 1.5 models
sd-1/ip_adapter/ip_adapter_plus_sd15:
  repo_id: InvokeAI/ip_adapter_plus_sd15
  recommended: False
  requires:
    - InvokeAI/ip_adapter_sd_image_encoder
  description: Refined IP-Adapter for SD 1.5 models
sd-1/ip_adapter/ip_adapter_plus_face_sd15:
  repo_id: InvokeAI/ip_adapter_plus_face_sd15
  recommended: False
  requires:
    - InvokeAI/ip_adapter_sd_image_encoder
  description: Refined IP-Adapter for SD 1.5 models, adapted for faces
sdxl/ip_adapter/ip_adapter_sdxl:
  repo_id: InvokeAI/ip_adapter_sdxl
  recommended: False
  requires:
    - InvokeAI/ip_adapter_sdxl_image_encoder
  description: IP-Adapter for SDXL models
any/clip_vision/ip_adapter_sd_image_encoder:
  repo_id: InvokeAI/ip_adapter_sd_image_encoder
  recommended: False
  description: Required model for using IP-Adapters with SD-1/2 models
any/clip_vision/ip_adapter_sdxl_image_encoder:
  repo_id: InvokeAI/ip_adapter_sdxl_image_encoder
  recommended: False
  description: Required model for using IP-Adapters with SDXL models
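Note: these `requires:` entries are what `_add_required_models` (in the model_install hunk earlier) consumes — selecting an IP-Adapter automatically pulls in its CLIP image encoder. A standalone sketch of that one-level dependency expansion over a plain dict:

# One-level dependency expansion, mirroring _add_required_models above.
requires = {
    "InvokeAI/ip_adapter_sd15": ["InvokeAI/ip_adapter_sd_image_encoder"],
    "InvokeAI/ip_adapter_sd_image_encoder": [],
}

def add_required_models(selected: list) -> list:
    additional = []
    for model in selected:
        for requirement in requires.get(model, []):
            if requirement not in selected and requirement not in additional:
                additional.append(requirement)
    return selected + additional

assert add_required_models(["InvokeAI/ip_adapter_sd15"]) == [
    "InvokeAI/ip_adapter_sd15",
    "InvokeAI/ip_adapter_sd_image_encoder",
]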
invokeai/configs/stable-diffusion/v1-inference-v.yaml (new file, 80 lines)
@@ -0,0 +1,80 @@
model:
  base_learning_rate: 1.0e-04
  target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    personalization_config:
      target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ['sculpture']
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder
@@ -45,7 +45,7 @@ from invokeai.frontend.install.widgets import (
)

config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.getLogger()
logger = InvokeAILogger.get_logger()

# build a table mapping all non-printable characters to None
# for stripping control characters
@@ -101,11 +101,12 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
                "STARTER MODELS",
                "MAIN MODELS",
                "CONTROLNETS",
                "IP-ADAPTERS",
                "LORA/LYCORIS",
                "TEXTUAL INVERSION",
            ],
            value=[self.current_tab],
            columns=5,
            columns=6,
            max_height=2,
            relx=8,
            scroll_exit=True,
@@ -130,6 +131,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
        )
        bottom_of_table = max(bottom_of_table, self.nextrely)

        self.nextrely = top_of_table
        self.ipadapter_models = self.add_model_widgets(
            model_type=ModelType.IPAdapter,
            window_width=window_width,
        )
        bottom_of_table = max(bottom_of_table, self.nextrely)

        self.nextrely = top_of_table
        self.lora_models = self.add_model_widgets(
            model_type=ModelType.Lora,
@@ -343,6 +351,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
            self.starter_pipelines,
            self.pipeline_models,
            self.controlnet_models,
            self.ipadapter_models,
            self.lora_models,
            self.ti_models,
        ]
@@ -532,6 +541,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
            self.starter_pipelines,
            self.pipeline_models,
            self.controlnet_models,
            self.ipadapter_models,
            self.lora_models,
            self.ti_models,
        ]
@@ -553,6 +563,25 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
            if downloads := section.get("download_ids"):
                selections.install_models.extend(downloads.value.split())

        # NOT NEEDED - DONE IN BACKEND NOW
        # # special case for the ipadapter_models. If any of the adapters are
        # # chosen, then we add the corresponding encoder(s) to the install list.
        # section = self.ipadapter_models
        # if section.get("models_selected"):
        #     selected_adapters = [
        #         self.all_models[section["models"][x]].name for x in section.get("models_selected").value
        #     ]
        #     encoders = []
        #     if any(["sdxl" in x for x in selected_adapters]):
        #         encoders.append("ip_adapter_sdxl_image_encoder")
        #     if any(["sd15" in x for x in selected_adapters]):
        #         encoders.append("ip_adapter_sd_image_encoder")
        #     for encoder in encoders:
        #         key = f"any/clip_vision/{encoder}"
        #         repo_id = f"InvokeAI/{encoder}"
        #         if key not in self.all_models:
        #             selections.install_models.append(repo_id)


class AddModelApplication(npyscreen.NPSAppManaged):
    def __init__(self, opt):
@@ -652,7 +681,7 @@ def process_and_execute(
        translator = StderrToMessage(conn_out)
        sys.stderr = translator
        sys.stdout = translator
        logger = InvokeAILogger.getLogger()
        logger = InvokeAILogger.get_logger()
        logger.handlers.clear()
        logger.addHandler(logging.StreamHandler(translator))

@@ -765,7 +794,7 @@ def main():
    if opt.full_precision:
        invoke_args.extend(["--precision", "float32"])
    config.parse_args(invoke_args)
    logger = InvokeAILogger().getLogger(config=config)
    logger = InvokeAILogger().get_logger(config=config)

    if not config.model_conf_path.exists():
        logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")

invokeai/frontend/web/dist/locales/en.json (vendored, 2 changes)
@@ -574,7 +574,7 @@
    "onnxModels": "Onnx",
    "pathToCustomConfig": "Path To Custom Config",
    "pickModelType": "Pick Model Type",
    "predictionType": "Prediction Type (for Stable Diffusion 2.x Models only)",
    "predictionType": "Prediction Type (for Stable Diffusion 2.x Models and occasional Stable Diffusion 1.x Models)",
    "quickAdd": "Quick Add",
    "repo_id": "Repo ID",
    "repoIDValidationMsg": "Online repository of your model",

@@ -79,7 +79,7 @@
    "lightMode": "Light Mode",
    "linear": "Linear",
    "load": "Load",
    "loading": "Loading",
    "loading": "Loading $t({{noun}})...",
    "loadingInvokeAI": "Loading Invoke AI",
    "learnMore": "Learn More",
    "modelManager": "Model Manager",
@@ -655,7 +655,7 @@
    "onnxModels": "Onnx",
    "pathToCustomConfig": "Path To Custom Config",
    "pickModelType": "Pick Model Type",
    "predictionType": "Prediction Type (for Stable Diffusion 2.x Models only)",
    "predictionType": "Prediction Type (for Stable Diffusion 2.x Models and occasional Stable Diffusion 1.x Models)",
    "quickAdd": "Quick Add",
    "repo_id": "Repo ID",
    "repoIDValidationMsg": "Online repository of your model",

@@ -17,7 +17,10 @@ import '../../i18n';
import AppDndContext from '../../features/dnd/components/AppDndContext';
import { $customStarUI, CustomStarUi } from 'app/store/nanostores/customStarUI';
import { $headerComponent } from 'app/store/nanostores/headerComponent';
import { $queueId, DEFAULT_QUEUE_ID } from 'features/queue/store/nanoStores';
import {
  $queueId,
  DEFAULT_QUEUE_ID,
} from 'features/queue/store/queueNanoStore';

const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));

@@ -81,3 +81,38 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
    </Flex>
  );
};

type IAINoImageFallbackWithSpinnerProps = FlexProps & {
  label?: string;
};

export const IAINoContentFallbackWithSpinner = (
  props: IAINoImageFallbackWithSpinnerProps
) => {
  const { sx, ...rest } = props;

  return (
    <Flex
      sx={{
        w: 'full',
        h: 'full',
        alignItems: 'center',
        justifyContent: 'center',
        borderRadius: 'base',
        flexDir: 'column',
        gap: 2,
        userSelect: 'none',
        opacity: 0.7,
        color: 'base.700',
        _dark: {
          color: 'base.500',
        },
        ...sx,
      }}
      {...rest}
    >
      <Spinner size="xl" />
      {props.label && <Text textAlign="center">{props.label}</Text>}
    </Flex>
  );
};

@@ -44,7 +44,7 @@ const IAIMantineMultiSelect = forwardRef((props: IAIMultiSelectProps, ref) => {

  return (
    <Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
      <FormControl ref={ref} isDisabled={disabled}>
      <FormControl ref={ref} isDisabled={disabled} position="static">
        {label && <FormLabel>{label}</FormLabel>}
        <MultiSelect
          ref={inputRef}
@@ -70,11 +70,10 @@ const IAIMantineSearchableSelect = forwardRef((props: IAISelectProps, ref) => {

  return (
    <Tooltip label={tooltip} placement="top" hasArrow>
      <FormControl ref={ref} isDisabled={disabled}>
      <FormControl ref={ref} isDisabled={disabled} position="static">
        {label && <FormLabel>{label}</FormLabel>}
        <Select
          ref={inputRef}
          withinPortal
          disabled={disabled}
          searchValue={searchValue}
          onSearchChange={setSearchValue}
@@ -22,7 +22,12 @@ const IAIMantineSelect = forwardRef((props: IAISelectProps, ref) => {

  return (
    <Tooltip label={tooltip} placement="top" hasArrow>
      <FormControl ref={ref} isRequired={required} isDisabled={disabled}>
      <FormControl
        ref={ref}
        isRequired={required}
        isDisabled={disabled}
        position="static"
      >
        <FormLabel>{label}</FormLabel>
        <Select disabled={disabled} ref={inputRef} styles={styles} {...rest} />
      </FormControl>
@@ -254,4 +254,5 @@ export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
  mediapipe: 'mediapipe_face_processor',
  pidi: 'pidi_image_processor',
  zoe: 'zoe_depth_image_processor',
  color: 'color_map_image_processor',
};

@@ -28,7 +28,7 @@ import {
  setShouldShowImageDetails,
  setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import {
@@ -41,10 +41,9 @@ import {
import { FaCircleNodes, FaEllipsis } from 'react-icons/fa6';
import {
  useGetImageDTOQuery,
  useGetImageMetadataQuery,
  useGetImageMetadataFromFileQuery,
} from 'services/api/endpoints/images';
import { menuListMotionProps } from 'theme/components/menu';
import { useDebounce } from 'use-debounce';
import { sentImageToImg2Img } from '../../store/actions';
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';

@@ -93,6 +92,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
    shouldShowImageDetails,
    lastSelectedImage,
    shouldShowProgressInViewer,
    shouldFetchMetadataFromApi,
  } = useAppSelector(currentImageButtonsSelector);

  const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
@@ -107,10 +107,16 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
    lastSelectedImage?.image_name ?? skipToken
  );

  const [debouncedImageName] = useDebounce(lastSelectedImage?.image_name, 300);
  const getMetadataArg = useMemo(() => {
    if (lastSelectedImage) {
      return { image: lastSelectedImage, shouldFetchMetadataFromApi };
    } else {
      return skipToken;
    }
  }, [lastSelectedImage, shouldFetchMetadataFromApi]);

  const { metadata, workflow, isLoading } = useGetImageMetadataQuery(
    debouncedImageName ?? skipToken,
  const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
    getMetadataArg,
    {
      selectFromResult: (res) => ({
        isLoading: res.isFetching,
@@ -281,7 +287,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
          icon={<FaSeedling />}
          tooltip={`${t('parameters.useSeed')} (S)`}
          aria-label={`${t('parameters.useSeed')} (S)`}
          isDisabled={!metadata?.seed}
          isDisabled={metadata?.seed === null || metadata?.seed === undefined}
          onClick={handleUseSeed}
        />
        <IAIIconButton

@@ -1,9 +1,8 @@
import { Flex, MenuItem, Spinner } from '@chakra-ui/react';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppToaster } from 'app/components/Toaster';
import { $customStarUI } from 'app/store/nanostores/customStarUI';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
  imagesToChangeSelected,
@@ -33,12 +32,12 @@ import {
import { FaCircleNodes } from 'react-icons/fa6';
import { MdStar, MdStarBorder } from 'react-icons/md';
import {
  useGetImageMetadataQuery,
  useGetImageMetadataFromFileQuery,
  useStarImagesMutation,
  useUnstarImagesMutation,
} from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import { configSelector } from '../../../system/store/configSelectors';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';

type SingleSelectionMenuItemsProps = {
@@ -54,12 +53,11 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
  const toaster = useAppToaster();

  const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
  const { shouldFetchMetadataFromApi } = useAppSelector(configSelector);
  const customStarUi = useStore($customStarUI);

  const [debouncedImageName] = useDebounce(imageDTO.image_name, 300);

  const { metadata, workflow, isLoading } = useGetImageMetadataQuery(
    debouncedImageName ?? skipToken,
  const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
    { image: imageDTO, shouldFetchMetadataFromApi },
    {
      selectFromResult: (res) => ({
        isLoading: res.isFetching,

@@ -9,15 +9,15 @@ import {
  Tabs,
  Text,
} from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { useGetImageMetadataFromFileQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import DataViewer from './DataViewer';
import ImageMetadataActions from './ImageMetadataActions';
import { useAppSelector } from '../../../../app/store/storeHooks';
import { configSelector } from '../../../system/store/configSelectors';
import { useTranslation } from 'react-i18next';

type ImageMetadataViewerProps = {
  image: ImageDTO;
@@ -31,10 +31,10 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
  // });
  const { t } = useTranslation();

  const [debouncedImageName] = useDebounce(image.image_name, 300);
  const { shouldFetchMetadataFromApi } = useAppSelector(configSelector);

  const { metadata, workflow } = useGetImageMetadataQuery(
    debouncedImageName ?? skipToken,
  const { metadata, workflow } = useGetImageMetadataFromFileQuery(
    { image, shouldFetchMetadataFromApi },
    {
      selectFromResult: (res) => ({
        metadata: res?.currentData?.metadata,

@@ -1,13 +1,13 @@
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEmbedWorkflow } from 'features/nodes/hooks/useEmbedWorkflow';
import { useWithWorkflow } from 'features/nodes/hooks/useWithWorkflow';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { nodeEmbedWorkflowChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';

const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
  const dispatch = useAppDispatch();
  const withWorkflow = useWithWorkflow(nodeId);
  const hasImageOutput = useHasImageOutput(nodeId);
  const embedWorkflow = useEmbedWorkflow(nodeId);
  const handleChange = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
@@ -21,7 +21,7 @@ const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
    [dispatch, nodeId]
  );

  if (!withWorkflow) {
  if (!hasImageOutput) {
    return null;
  }


@@ -8,6 +8,7 @@ import InvocationNodeFooter from './InvocationNodeFooter';
import InvocationNodeHeader from './InvocationNodeHeader';
import InputField from './fields/InputField';
import OutputField from './fields/OutputField';
import { useWithFooter } from 'features/nodes/hooks/useWithFooter';

type Props = {
  nodeId: string;
@@ -20,6 +21,7 @@ type Props = {
const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
  const inputConnectionFieldNames = useConnectionInputFieldNames(nodeId);
  const inputAnyOrDirectFieldNames = useAnyOrDirectInputFieldNames(nodeId);
  const withFooter = useWithFooter(nodeId);
  const outputFieldNames = useOutputFieldNames(nodeId);

  return (
@@ -41,7 +43,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
          h: 'full',
          py: 2,
          gap: 1,
          borderBottomRadius: 0,
          borderBottomRadius: withFooter ? 0 : 'base',
        }}
      >
        <Flex sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }}>
@@ -74,7 +76,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
            ))}
          </Flex>
        </Flex>
        <InvocationNodeFooter nodeId={nodeId} />
        {withFooter && <InvocationNodeFooter nodeId={nodeId} />}
      </>
    )}
  </NodeWrapper>

@@ -1,10 +1,11 @@
import { Flex } from '@chakra-ui/react';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { memo } from 'react';
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
import UseCacheCheckbox from './UseCacheCheckbox';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { useFeatureStatus } from '../../../../../system/hooks/useFeatureStatus';

type Props = {
  nodeId: string;
@@ -12,6 +13,7 @@ type Props = {

const InvocationNodeFooter = ({ nodeId }: Props) => {
  const hasImageOutput = useHasImageOutput(nodeId);
  const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
  return (
    <Flex
      className={DRAG_HANDLE_CLASSNAME}
@@ -25,7 +27,7 @@ const InvocationNodeFooter = ({ nodeId }: Props) => {
        justifyContent: 'space-between',
      }}
    >
      <UseCacheCheckbox nodeId={nodeId} />
      {isCacheEnabled && <UseCacheCheckbox nodeId={nodeId} />}
      {hasImageOutput && <EmbedWorkflowCheckbox nodeId={nodeId} />}
      {hasImageOutput && <SaveToGalleryCheckbox nodeId={nodeId} />}
    </Flex>

@@ -3,12 +3,7 @@ import graphlib from '@dagrejs/graphlib';
import { useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
import {
  COLLECTION_MAP,
  COLLECTION_TYPES,
  POLYMORPHIC_TO_SINGLE_MAP,
  POLYMORPHIC_TYPES,
} from '../types/constants';
import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes';
import { InvocationNodeData } from '../types/types';

/**
@@ -23,11 +18,6 @@ export const useIsValidConnection = () => {
  );
  const isValidConnection = useCallback(
    ({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
      if (!shouldValidateGraph) {
        // manual override!
        return true;
      }

      const edges = flow.getEdges();
      const nodes = flow.getNodes();
      // Connection must have valid targets
@@ -52,6 +42,16 @@ export const useIsValidConnection = () => {
        return false;
      }

      if (source === target) {
        // Don't allow nodes to connect to themselves, even if validation is disabled
        return false;
      }

      if (!shouldValidateGraph) {
        // manual override!
        return true;
      }

      if (
        edges
          .filter((edge) => {
@@ -76,63 +76,8 @@ export const useIsValidConnection = () => {
        return false;
      }

      /**
       * Connection types must be the same for a connection, with exceptions:
       * - CollectionItem can connect to any non-Collection
       * - Non-Collections can connect to CollectionItem
       * - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
       * - Generic Collection can connect to any other Collection or Polymorphic
       * - Any Collection can connect to a Generic Collection
       */

      if (sourceType !== targetType) {
        const isCollectionItemToNonCollection =
          sourceType === 'CollectionItem' &&
          !COLLECTION_TYPES.includes(targetType);

        const isNonCollectionToCollectionItem =
          targetType === 'CollectionItem' &&
          !COLLECTION_TYPES.includes(sourceType) &&
          !POLYMORPHIC_TYPES.includes(sourceType);

        const isAnythingToPolymorphicOfSameBaseType =
          POLYMORPHIC_TYPES.includes(targetType) &&
          (() => {
            if (!POLYMORPHIC_TYPES.includes(targetType)) {
              return false;
            }
            const baseType =
              POLYMORPHIC_TO_SINGLE_MAP[
                targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
              ];

            const collectionType =
              COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];

            return sourceType === baseType || sourceType === collectionType;
          })();

        const isGenericCollectionToAnyCollectionOrPolymorphic =
          sourceType === 'Collection' &&
          (COLLECTION_TYPES.includes(targetType) ||
            POLYMORPHIC_TYPES.includes(targetType));

        const isCollectionToGenericCollection =
          targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);

        const isIntToFloat = sourceType === 'integer' && targetType === 'float';

        const isEitherAnyType = sourceType === 'Any' || targetType === 'Any';

        return (
          isCollectionItemToNonCollection ||
          isNonCollectionToCollectionItem ||
          isAnythingToPolymorphicOfSameBaseType ||
          isGenericCollectionToAnyCollectionOrPolymorphic ||
          isCollectionToGenericCollection ||
          isIntToFloat ||
          isEitherAnyType
        );
      if (!validateSourceAndTargetTypes(sourceType, targetType)) {
        return false;
      }

      // Graphs must be acyclic (no loops!)

@@ -1,31 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { some } from 'lodash-es';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useMemo } from 'react';
import { FOOTER_FIELDS } from '../types/constants';
import { isInvocationNode } from '../types/types';
import { useHasImageOutput } from './useHasImageOutput';

export const useHasImageOutputs = (nodeId: string) => {
  const selector = useMemo(
    () =>
      createSelector(
        stateSelector,
        ({ nodes }) => {
          const node = nodes.nodes.find((node) => node.id === nodeId);
          if (!isInvocationNode(node)) {
            return false;
          }
          return some(node.data.outputs, (output) =>
            FOOTER_FIELDS.includes(output.type)
          );
        },
        defaultSelectorOptions
      ),
    [nodeId]
export const useWithFooter = (nodeId: string) => {
  const hasImageOutput = useHasImageOutput(nodeId);
  const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;

  const withFooter = useMemo(
    () => hasImageOutput || isCacheEnabled,
    [hasImageOutput, isCacheEnabled]
  );

  const withFooter = useAppSelector(selector);
  return withFooter;
};

@@ -1,31 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';

export const useWithWorkflow = (nodeId: string) => {
  const selector = useMemo(
    () =>
      createSelector(
        stateSelector,
        ({ nodes }) => {
          const node = nodes.nodes.find((node) => node.id === nodeId);
          if (!isInvocationNode(node)) {
            return false;
          }
          const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
          if (!nodeTemplate) {
            return false;
          }
          return nodeTemplate.withWorkflow;
        },
        defaultSelectorOptions
      ),
    [nodeId]
  );

  const withWorkflow = useAppSelector(selector);
  return withWorkflow;
};
@@ -1,15 +1,10 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
import {
  COLLECTION_MAP,
  COLLECTION_TYPES,
  POLYMORPHIC_TO_SINGLE_MAP,
  POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types';
import { HandleType } from 'reactflow';
import i18n from 'i18next';
import { HandleType } from 'reactflow';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';

/**
 * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
@@ -70,67 +65,8 @@ export const makeConnectionErrorSelector = (
    return i18n.t('nodes.inputMayOnlyHaveOneConnection');
  }

  /**
   * Connection types must be the same for a connection, with exceptions:
   * - CollectionItem can connect to any non-Collection
   * - Non-Collections can connect to CollectionItem
   * - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
   * - Generic Collection can connect to any other Collection or Polymorphic
   * - Any Collection can connect to a Generic Collection
   */

  if (sourceType !== targetType) {
    const isCollectionItemToNonCollection =
      sourceType === 'CollectionItem' &&
      !COLLECTION_TYPES.includes(targetType);

    const isNonCollectionToCollectionItem =
      targetType === 'CollectionItem' &&
      !COLLECTION_TYPES.includes(sourceType) &&
      !POLYMORPHIC_TYPES.includes(sourceType);

    const isAnythingToPolymorphicOfSameBaseType =
      POLYMORPHIC_TYPES.includes(targetType) &&
      (() => {
        if (!POLYMORPHIC_TYPES.includes(targetType)) {
          return false;
        }
        const baseType =
          POLYMORPHIC_TO_SINGLE_MAP[
            targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
          ];

        const collectionType =
          COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];

        return sourceType === baseType || sourceType === collectionType;
      })();

    const isGenericCollectionToAnyCollectionOrPolymorphic =
      sourceType === 'Collection' &&
      (COLLECTION_TYPES.includes(targetType) ||
        POLYMORPHIC_TYPES.includes(targetType));

    const isCollectionToGenericCollection =
      targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);

    const isIntToFloat = sourceType === 'integer' && targetType === 'float';

    const isEitherAnyType = sourceType === 'Any' || targetType === 'Any';

    if (
      !(
        isCollectionItemToNonCollection ||
        isNonCollectionToCollectionItem ||
        isAnythingToPolymorphicOfSameBaseType ||
        isGenericCollectionToAnyCollectionOrPolymorphic ||
        isCollectionToGenericCollection ||
        isIntToFloat ||
        isEitherAnyType
      )
    ) {
      return i18n.t('nodes.fieldTypesMustMatch');
    }
  if (!validateSourceAndTargetTypes(sourceType, targetType)) {
    return i18n.t('nodes.fieldTypesMustMatch');
  }

  const isGraphAcyclic = getIsGraphAcyclic(

@@ -0,0 +1,74 @@
import {
  COLLECTION_MAP,
  COLLECTION_TYPES,
  POLYMORPHIC_TO_SINGLE_MAP,
  POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types';

export const validateSourceAndTargetTypes = (
  sourceType: FieldType,
  targetType: FieldType
) => {
  if (sourceType === targetType) {
    return true;
  }

  /**
   * Connection types must be the same for a connection, with exceptions:
   * - CollectionItem can connect to any non-Collection
   * - Non-Collections can connect to CollectionItem
   * - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
   * - Generic Collection can connect to any other Collection or Polymorphic
   * - Any Collection can connect to a Generic Collection
   */

  const isCollectionItemToNonCollection =
    sourceType === 'CollectionItem' && !COLLECTION_TYPES.includes(targetType);

  const isNonCollectionToCollectionItem =
    targetType === 'CollectionItem' &&
    !COLLECTION_TYPES.includes(sourceType) &&
    !POLYMORPHIC_TYPES.includes(sourceType);

  const isAnythingToPolymorphicOfSameBaseType =
    POLYMORPHIC_TYPES.includes(targetType) &&
    (() => {
      if (!POLYMORPHIC_TYPES.includes(targetType)) {
        return false;
      }
      const baseType =
        POLYMORPHIC_TO_SINGLE_MAP[
          targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
        ];

      const collectionType =
        COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];

      return sourceType === baseType || sourceType === collectionType;
    })();

  const isGenericCollectionToAnyCollectionOrPolymorphic =
    sourceType === 'Collection' &&
    (COLLECTION_TYPES.includes(targetType) ||
      POLYMORPHIC_TYPES.includes(targetType));

  const isCollectionToGenericCollection =
    targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);

  const isIntToFloat = sourceType === 'integer' && targetType === 'float';

  const isIntOrFloatToString =
    (sourceType === 'integer' || sourceType === 'float') &&
    targetType === 'string';

  return (
    isCollectionItemToNonCollection ||
    isNonCollectionToCollectionItem ||
    isAnythingToPolymorphicOfSameBaseType ||
    isGenericCollectionToAnyCollectionOrPolymorphic ||
    isCollectionToGenericCollection ||
    isIntToFloat ||
    isIntOrFloatToString
  );
};
@@ -31,8 +31,6 @@ export const COLLECTION_TYPES: FieldType[] = [
  'ConditioningCollection',
  'ControlCollection',
  'ColorCollection',
  'MetadataItemCollection',
  'MetadataDictCollection',
];

export const POLYMORPHIC_TYPES: FieldType[] = [
@@ -45,7 +43,6 @@ export const POLYMORPHIC_TYPES: FieldType[] = [
  'ConditioningPolymorphic',
  'ControlPolymorphic',
  'ColorPolymorphic',
  'MetadataItemPolymorphic',
];

export const MODEL_TYPES: FieldType[] = [
@@ -73,8 +70,6 @@ export const COLLECTION_MAP: FieldTypeMapWithNumber = {
  ConditioningField: 'ConditioningCollection',
  ControlField: 'ControlCollection',
  ColorField: 'ColorCollection',
  MetadataItem: 'MetadataItemCollection',
  MetadataDict: 'MetadataDictCollection',
};
export const isCollectionItemType = (
  itemType: string | undefined
@@ -92,7 +87,6 @@ export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = {
  ConditioningField: 'ConditioningPolymorphic',
  ControlField: 'ControlPolymorphic',
  ColorField: 'ColorPolymorphic',
  MetadataItem: 'MetadataItemPolymorphic',
};

export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
@@ -105,7 +99,6 @@ export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
  ConditioningPolymorphic: 'ConditioningField',
  ControlPolymorphic: 'ControlField',
  ColorPolymorphic: 'ColorField',
  MetadataItemPolymorphic: 'MetadataItem',
};

export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
@@ -138,37 +131,6 @@ export const isPolymorphicItemType = (
  Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);

export const FIELDS: Record<FieldType, FieldUIConfig> = {
  Any: {
    color: 'gray.500',
    description: 'Any field type is accepted.',
    title: 'Any',
  },
  MetadataDict: {
    color: 'gray.500',
    description: 'A metadata dict.',
    title: 'Metadata Dict',
  },
  MetadataDictCollection: {
    color: 'gray.500',
    description: 'A collection of metadata dicts.',
    title: 'Metadata Dict Collection',
  },
  MetadataItem: {
    color: 'gray.500',
    description: 'A metadata item.',
    title: 'Metadata Item',
  },
  MetadataItemCollection: {
    color: 'gray.500',
    description: 'Any field type is accepted.',
    title: 'Metadata Item Collection',
  },
  MetadataItemPolymorphic: {
    color: 'gray.500',
    description:
      'MetadataItem or MetadataItemCollection field types are accepted.',
    title: 'Metadata Item Polymorphic',
  },
  boolean: {
    color: 'green.500',
    description: t('nodes.booleanDescription'),

@@ -54,10 +54,6 @@ export type InvocationTemplate = {
   * The type of this node's output
   */
  outputType: string; // TODO: generate a union of output types
  /**
   * Whether or not this invocation supports workflows
   */
  withWorkflow: boolean;
  /**
   * The invocation's version.
   */
@@ -76,7 +72,6 @@ export type FieldUIConfig = {

// TODO: Get this from the OpenAPI schema? may be tricky...
export const zFieldType = z.enum([
  'Any',
  'BoardField',
  'boolean',
  'BooleanCollection',
@@ -112,11 +107,6 @@ export const zFieldType = z.enum([
  'LatentsPolymorphic',
  'LoRAModelField',
  'MainModelField',
  'MetadataDict',
  'MetadataDictCollection',
  'MetadataItem',
  'MetadataItemCollection',
  'MetadataItemPolymorphic',
  'ONNXModelField',
  'Scheduler',
  'SDXLMainModelField',
@@ -617,58 +607,6 @@ export type CollectionItemInputFieldValue = z.infer<
  typeof zCollectionItemInputFieldValue
>;

export const zMetadataItem = z.object({
  label: z.string(),
  value: z.any(),
});
export type MetadataItem = z.infer<typeof zMetadataItem>;

export const zMetadataItemInputFieldValue = zInputFieldValueBase.extend({
  type: z.literal('MetadataItem'),
  value: zMetadataItem.optional(),
});
export type MetadataItemInputFieldValue = z.infer<
  typeof zMetadataItemInputFieldValue
>;

export const zMetadataItemCollectionInputFieldValue =
  zInputFieldValueBase.extend({
    type: z.literal('MetadataItemCollection'),
    value: z.array(zMetadataItem).optional(),
  });
export type MetadataItemCollectionInputFieldValue = z.infer<
  typeof zMetadataItemCollectionInputFieldValue
>;

export const zMetadataItemPolymorphicInputFieldValue =
  zInputFieldValueBase.extend({
    type: z.literal('MetadataItemPolymorphic'),
    value: z.union([zMetadataItem, z.array(zMetadataItem)]).optional(),
  });
export type MetadataItemPolymorphicInputFieldValue = z.infer<
  typeof zMetadataItemPolymorphicInputFieldValue
>;

export const zMetadataDict = z.record(z.any());
export type MetadataDict = z.infer<typeof zMetadataDict>;

export const zMetadataDictInputFieldValue = zInputFieldValueBase.extend({
  type: z.literal('MetadataDict'),
  value: zMetadataDict.optional(),
});
export type MetadataDictInputFieldValue = z.infer<
  typeof zMetadataDictInputFieldValue
>;

export const zMetadataDictCollectionInputFieldValue =
  zInputFieldValueBase.extend({
    type: z.literal('MetadataDictCollection'),
    value: z.array(zMetadataDict).optional(),
  });
export type MetadataDictCollectionInputFieldValue = z.infer<
  typeof zMetadataDictCollectionInputFieldValue
>;

export const zColorField = z.object({
  r: z.number().int().min(0).max(255),
  g: z.number().int().min(0).max(255),
@@ -707,13 +645,7 @@ export type SchedulerInputFieldValue = z.infer<
  typeof zSchedulerInputFieldValue
>;

export const zAnyInputFieldValue = zInputFieldValueBase.extend({
  type: z.literal('Any'),
  value: z.any().optional(),
});

export const zInputFieldValue = z.discriminatedUnion('type', [
  zAnyInputFieldValue,
  zBoardInputFieldValue,
  zBooleanCollectionInputFieldValue,
  zBooleanInputFieldValue,
@@ -758,11 +690,6 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
  zUNetInputFieldValue,
  zVaeInputFieldValue,
  zVaeModelInputFieldValue,
  zMetadataItemInputFieldValue,
  zMetadataItemCollectionInputFieldValue,
  zMetadataItemPolymorphicInputFieldValue,
  zMetadataDictInputFieldValue,
  zMetadataDictCollectionInputFieldValue,
]);

export type InputFieldValue = z.infer<typeof zInputFieldValue>;
@@ -775,11 +702,6 @@ export type InputFieldTemplateBase = {
  fieldKind: 'input';
} & _InputField;

export type AnyInputFieldTemplate = InputFieldTemplateBase & {
  type: 'Any';
  default: undefined;
};

export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
  type: 'integer';
  default: number;
@@ -933,11 +855,6 @@ export type UNetInputFieldTemplate = InputFieldTemplateBase & {
  type: 'UNetField';
};

export type MetadataItemFieldTemplate = InputFieldTemplateBase & {
  default: undefined;
  type: 'UNetField';
};

export type ClipInputFieldTemplate = InputFieldTemplateBase & {
  default: undefined;
  type: 'ClipField';
@@ -1050,35 +967,6 @@ export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
  type: 'WorkflowField';
};

export type MetadataItemInputFieldTemplate = InputFieldTemplateBase & {
  default: undefined;
  type: 'MetadataItem';
};

export type MetadataItemCollectionInputFieldTemplate =
  InputFieldTemplateBase & {
    default: undefined;
    type: 'MetadataItemCollection';
  };

export type MetadataItemPolymorphicInputFieldTemplate = Omit<
  MetadataItemInputFieldTemplate,
  'type'
> & {
  type: 'MetadataItemPolymorphic';
};

export type MetadataDictInputFieldTemplate = InputFieldTemplateBase & {
  default: undefined;
  type: 'MetadataDict';
};

export type MetadataDictCollectionInputFieldTemplate =
  InputFieldTemplateBase & {
    default: undefined;
    type: 'MetadataDictCollection';
  };

/**
 * An input field template is generated on each page load from the OpenAPI schema.
 *
@@ -1086,7 +974,6 @@ export type MetadataDictCollectionInputFieldTemplate =
 * maximum length, pattern to match, etc).
 */
export type InputFieldTemplate =
  | AnyInputFieldTemplate
  | BoardInputFieldTemplate
  | BooleanCollectionInputFieldTemplate
  | BooleanPolymorphicInputFieldTemplate
@@ -1130,12 +1017,7 @@ export type InputFieldTemplate =
  | StringInputFieldTemplate
  | UNetInputFieldTemplate
  | VaeInputFieldTemplate
  | VaeModelInputFieldTemplate
  | MetadataItemInputFieldTemplate
  | MetadataItemCollectionInputFieldTemplate
  | MetadataDictInputFieldTemplate
  | MetadataItemPolymorphicInputFieldTemplate
  | MetadataDictCollectionInputFieldTemplate;
  | VaeModelInputFieldTemplate;

export const isInputFieldValue = (
  field?: InputFieldValue | OutputFieldValue
@@ -1252,7 +1134,7 @@ export const isInvocationFieldSchema = (

export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };

export const zLoRAMetadataItem = z.object({
const zLoRAMetadataItem = z.object({
  lora: zLoRAModelField.deepPartial(),
  weight: z.number(),
});
@@ -1279,7 +1161,15 @@ export const zCoreMetadata = z
  .nullish()
  .catch(null),
  controlnets: z.array(zControlField.deepPartial()).nullish().catch(null),
  loras: z.array(zLoRAMetadataItem).nullish().catch(null),
  loras: z
    .array(
      z.object({
        lora: zLoRAModelField.deepPartial(),
        weight: z.number(),
      })
    )
    .nullish()
    .catch(null),
  vae: zVaeModelField.nullish().catch(null),
  strength: z.number().nullish().catch(null),
  init_image: z.string().nullish().catch(null),

@@ -1,6 +1,5 @@
import { isBoolean, isInteger, isNumber, isString } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import { ControlField } from 'services/api/types';
import {
  COLLECTION_MAP,
  POLYMORPHIC_TYPES,
@@ -9,61 +8,36 @@ import {
  isPolymorphicItemType,
} from '../types/constants';
import {
  AnyInputFieldTemplate,
  BoardInputFieldTemplate,
  BooleanCollectionInputFieldTemplate,
  BooleanInputFieldTemplate,
  BooleanPolymorphicInputFieldTemplate,
  ClipInputFieldTemplate,
  CollectionInputFieldTemplate,
  CollectionItemInputFieldTemplate,
  ColorCollectionInputFieldTemplate,
  ColorInputFieldTemplate,
  ColorPolymorphicInputFieldTemplate,
  ConditioningCollectionInputFieldTemplate,
  ConditioningField,
  ConditioningInputFieldTemplate,
  ConditioningPolymorphicInputFieldTemplate,
  ControlCollectionInputFieldTemplate,
  ControlInputFieldTemplate,
  ControlNetModelInputFieldTemplate,
  ControlPolymorphicInputFieldTemplate,
  DenoiseMaskInputFieldTemplate,
  EnumInputFieldTemplate,
  FieldType,
  FloatCollectionInputFieldTemplate,
  FloatInputFieldTemplate,
  FloatPolymorphicInputFieldTemplate,
  IPAdapterInputFieldTemplate,
  IPAdapterModelInputFieldTemplate,
  FloatInputFieldTemplate,
  ImageCollectionInputFieldTemplate,
  ImageField,
  ImageInputFieldTemplate,
  ImagePolymorphicInputFieldTemplate,
  InputFieldTemplate,
  InputFieldTemplateBase,
  IntegerCollectionInputFieldTemplate,
  IntegerInputFieldTemplate,
  IntegerPolymorphicInputFieldTemplate,
  InvocationFieldSchema,
  InvocationSchemaObject,
  LatentsCollectionInputFieldTemplate,
  LatentsField,
  LatentsInputFieldTemplate,
  LatentsPolymorphicInputFieldTemplate,
  LoRAModelInputFieldTemplate,
  MainModelInputFieldTemplate,
  MetadataDictCollectionInputFieldTemplate,
  MetadataDictInputFieldTemplate,
  MetadataItemCollectionInputFieldTemplate,
  MetadataItemInputFieldTemplate,
  MetadataItemPolymorphicInputFieldTemplate,
  SDXLMainModelInputFieldTemplate,
  SDXLRefinerModelInputFieldTemplate,
  SchedulerInputFieldTemplate,
  StringCollectionInputFieldTemplate,
  StringInputFieldTemplate,
  StringPolymorphicInputFieldTemplate,
  UNetInputFieldTemplate,
  VaeInputFieldTemplate,
  VaeModelInputFieldTemplate,
@@ -71,7 +45,27 @@ import {
  isNonArraySchemaObject,
  isRefObject,
  isSchemaObject,
  ControlPolymorphicInputFieldTemplate,
  ColorPolymorphicInputFieldTemplate,
  ColorCollectionInputFieldTemplate,
  IntegerPolymorphicInputFieldTemplate,
  StringPolymorphicInputFieldTemplate,
  BooleanPolymorphicInputFieldTemplate,
  ImagePolymorphicInputFieldTemplate,
  LatentsPolymorphicInputFieldTemplate,
  LatentsCollectionInputFieldTemplate,
  ConditioningPolymorphicInputFieldTemplate,
  ConditioningCollectionInputFieldTemplate,
  ControlCollectionInputFieldTemplate,
  ImageField,
  LatentsField,
  ConditioningField,
  IPAdapterInputFieldTemplate,
  IPAdapterModelInputFieldTemplate,
  BoardInputFieldTemplate,
  InputFieldTemplate,
} from '../types/types';
import { ControlField } from 'services/api/types';

export type BaseFieldProperties = 'name' | 'title' | 'description';

@@ -737,78 +731,6 @@ const buildCollectionItemInputFieldTemplate = ({
  return template;
};

const buildAnyInputFieldTemplate = ({
  baseField,
}: BuildInputFieldArg): AnyInputFieldTemplate => {
  const template: AnyInputFieldTemplate = {
    ...baseField,
    type: 'Any',
    default: undefined,
  };

  return template;
};

const buildMetadataItemInputFieldTemplate = ({
  baseField,
}: BuildInputFieldArg): MetadataItemInputFieldTemplate => {
  const template: MetadataItemInputFieldTemplate = {
    ...baseField,
    type: 'MetadataItem',
    default: undefined,
  };

  return template;
};

const buildMetadataItemCollectionInputFieldTemplate = ({
  baseField,
}: BuildInputFieldArg): MetadataItemCollectionInputFieldTemplate => {
  const template: MetadataItemCollectionInputFieldTemplate = {
    ...baseField,
    type: 'MetadataItemCollection',
    default: undefined,
  };

  return template;
};

const buildMetadataItemPolymorphicInputFieldTemplate = ({
  baseField,
}: BuildInputFieldArg): MetadataItemPolymorphicInputFieldTemplate => {
  const template: MetadataItemPolymorphicInputFieldTemplate = {
    ...baseField,
    type: 'MetadataItemPolymorphic',
    default: undefined,
  };

  return template;
};

const buildMetadataDictInputFieldTemplate = ({
  baseField,
}: BuildInputFieldArg): MetadataDictInputFieldTemplate => {
  const template: MetadataDictInputFieldTemplate = {
    ...baseField,
    type: 'MetadataDict',
    default: undefined,
  };

  return template;
};

const buildMetadataDictCollectionInputFieldTemplate = ({
  baseField,
}: BuildInputFieldArg): MetadataDictCollectionInputFieldTemplate => {
  const template: MetadataDictCollectionInputFieldTemplate = {
    ...baseField,
    type: 'MetadataDictCollection',
    default: undefined,
  };

  return template;
};

const buildColorInputFieldTemplate = ({
  schemaObject,
  baseField,
@@ -948,7 +870,6 @@ const TEMPLATE_BUILDER_MAP: {
  [key in FieldType]?: (arg: BuildInputFieldArg) => InputFieldTemplate;
} = {
  BoardField: buildBoardInputFieldTemplate,
  Any: buildAnyInputFieldTemplate,
  boolean: buildBooleanInputFieldTemplate,
  BooleanCollection: buildBooleanCollectionInputFieldTemplate,
  BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate,
@@ -982,11 +903,6 @@ const TEMPLATE_BUILDER_MAP: {
  LatentsField: buildLatentsInputFieldTemplate,
  LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
  LoRAModelField: buildLoRAModelInputFieldTemplate,
  MetadataItem: buildMetadataItemInputFieldTemplate,
  MetadataItemCollection: buildMetadataItemCollectionInputFieldTemplate,
  MetadataItemPolymorphic: buildMetadataItemPolymorphicInputFieldTemplate,
  MetadataDict: buildMetadataDictInputFieldTemplate,
  MetadataDictCollection: buildMetadataDictCollectionInputFieldTemplate,
  MainModelField: buildMainModelInputFieldTemplate,
  Scheduler: buildSchedulerInputFieldTemplate,
  SDXLMainModelField: buildSDXLMainModelInputFieldTemplate,

@@ -3,7 +3,6 @@ import { FieldType, InputFieldTemplate, InputFieldValue } from '../types/types';
const FIELD_VALUE_FALLBACK_MAP: {
  [key in FieldType]: InputFieldValue['value'];
} = {
  Any: undefined,
  enum: '',
  BoardField: undefined,
  boolean: false,
@@ -37,11 +36,6 @@ const FIELD_VALUE_FALLBACK_MAP: {
  LatentsCollection: [],
  LatentsField: undefined,
  LatentsPolymorphic: undefined,
  MetadataItem: undefined,
  MetadataItemCollection: [],
  MetadataItemPolymorphic: undefined,
  MetadataDict: undefined,
  MetadataDictCollection: [],
  LoRAModelField: undefined,
  MainModelField: undefined,
  ONNXModelField: undefined,

@@ -1,16 +1,18 @@
import { RootState } from 'app/store/store';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
import { omit } from 'lodash-es';
import {
  CollectInvocation,
  ControlField,
  ControlNetInvocation,
  MetadataAccumulatorInvocation,
} from 'services/api/types';
import { NonNullableGraph, zControlField } from '../../types/types';
import { NonNullableGraph } from '../../types/types';
import {
  CANVAS_COHERENCE_DENOISE_LATENTS,
  CONTROL_NET_COLLECT,
  METADATA_ACCUMULATOR,
} from './constants';
import { addMainMetadata } from './metadata';

export const addControlNetToLinearGraph = (
  state: RootState,
@@ -21,9 +23,12 @@ export const addControlNetToLinearGraph = (

  const validControlNets = getValidControlNets(controlNets);

  const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
    | MetadataAccumulatorInvocation
    | undefined;

  if (isControlNetEnabled && Boolean(validControlNets.length)) {
  if (validControlNets.length) {
    const controlnets: ControlField[] = [];
    // We have multiple controlnets, add ControlNet collector
    const controlNetIterateNode: CollectInvocation = {
      id: CONTROL_NET_COLLECT,
@@ -82,7 +87,15 @@ export const addControlNetToLinearGraph = (

      graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;

      controlnets.push(zControlField.parse(controlNetNode));
      if (metadataAccumulator?.controlnets) {
        // metadata accumulator only needs a control field - not the whole node
        // extract what we need and add to the accumulator
        const controlField = omit(controlNetNode, [
          'id',
          'type',
        ]) as ControlField;
        metadataAccumulator.controlnets.push(controlField);
      }

      graph.edges.push({
        source: { node_id: controlNetNode.id, field: 'control' },
@@ -102,8 +115,6 @@ export const addControlNetToLinearGraph = (
        });
      }
    });

    addMainMetadata(graph, { controlnets });
  }
  }
};

@@ -38,7 +38,15 @@ export const addIPAdapterToLinearGraph = (

  graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;

  // TODO: add metadata
  // if (metadataAccumulator?.ip_adapters) {
  //   // metadata accumulator only needs the ip_adapter field - not the whole node
  //   // extract what we need and add to the accumulator
  //   const ipAdapterField = omit(ipAdapterNode, [
  //     'id',
  //     'type',
  //   ]) as IPAdapterField;
  //   metadataAccumulator.ip_adapters.push(ipAdapterField);
  // }

  graph.edges.push({
    source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },

@@ -1,22 +1,21 @@
import { RootState } from 'app/store/store';
import {
  LoRAMetadataItem,
  NonNullableGraph,
  zLoRAMetadataItem,
} from 'features/nodes/types/types';
import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es';
import { LoraLoaderInvocation } from 'services/api/types';
import {
  CANVAS_COHERENCE_DENOISE_LATENTS,
  LoraLoaderInvocation,
  MetadataAccumulatorInvocation,
} from 'services/api/types';
import {
  CANVAS_INPAINT_GRAPH,
  CANVAS_OUTPAINT_GRAPH,
  CANVAS_COHERENCE_DENOISE_LATENTS,
  CLIP_SKIP,
  LORA_LOADER,
  MAIN_MODEL_LOADER,
  METADATA_ACCUMULATOR,
  NEGATIVE_CONDITIONING,
  POSITIVE_CONDITIONING,
} from './constants';
import { addMainMetadata } from './metadata';

export const addLoRAsToGraph = (
  state: RootState,
@@ -34,29 +33,29 @@ export const addLoRAsToGraph = (

  const { loras } = state.lora;
  const loraCount = size(loras);
  const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
    | MetadataAccumulatorInvocation
    | undefined;

  if (loraCount === 0) {
    return;
  if (loraCount > 0) {
    // Remove modelLoaderNodeId unet connection to feed it to LoRAs
    graph.edges = graph.edges.filter(
      (e) =>
        !(
          e.source.node_id === modelLoaderNodeId &&
          ['unet'].includes(e.source.field)
        )
    );
    // Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
    graph.edges = graph.edges.filter(
      (e) =>
        !(e.source.node_id === CLIP_SKIP && ['clip'].includes(e.source.field))
    );
  }

  // Remove modelLoaderNodeId unet connection to feed it to LoRAs
  graph.edges = graph.edges.filter(
    (e) =>
      !(
        e.source.node_id === modelLoaderNodeId &&
        ['unet'].includes(e.source.field)
      )
  );
  // Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
  graph.edges = graph.edges.filter(
    (e) =>
      !(e.source.node_id === CLIP_SKIP && ['clip'].includes(e.source.field))
  );

  // we need to remember the last lora so we can chain from it
  let lastLoraNodeId = '';
  let currentLoraIndex = 0;
  const loraMetadata: LoRAMetadataItem[] = [];

  forEach(loras, (lora) => {
    const { model_name, base_model, weight } = lora;
@@ -70,12 +69,13 @@ export const addLoRAsToGraph = (
      weight,
    };

    loraMetadata.push(
      zLoRAMetadataItem.parse({
    // add the lora to the metadata accumulator
    if (metadataAccumulator?.loras) {
      metadataAccumulator.loras.push({
        lora: { model_name, base_model },
        weight,
      })
    );
      });
    }

    // add to graph
    graph.nodes[currentLoraNodeId] = loraLoaderNode;
@@ -182,6 +182,4 @@ export const addLoRAsToGraph = (
    lastLoraNodeId = currentLoraNodeId;
    currentLoraIndex += 1;
  });

  addMainMetadata(graph, { loras: loraMetadata });
};

@@ -1,14 +1,14 @@
import { RootState } from 'app/store/store';
import {
  LoRAMetadataItem,
  NonNullableGraph,
  zLoRAMetadataItem,
} from 'features/nodes/types/types';
import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es';
import { SDXLLoraLoaderInvocation } from 'services/api/types';
import {
  MetadataAccumulatorInvocation,
  SDXLLoraLoaderInvocation,
} from 'services/api/types';
import {
  CANVAS_COHERENCE_DENOISE_LATENTS,
  LORA_LOADER,
  METADATA_ACCUMULATOR,
  NEGATIVE_CONDITIONING,
  POSITIVE_CONDITIONING,
  SDXL_CANVAS_INPAINT_GRAPH,
@@ -17,7 +17,6 @@ import {
  SDXL_REFINER_INPAINT_CREATE_MASK,
  SEAMLESS,
} from './constants';
import { addMainMetadata } from './metadata';

export const addSDXLLoRAsToGraph = (
  state: RootState,
@@ -35,12 +34,9 @@ export const addSDXLLoRAsToGraph = (

  const { loras } = state.lora;
  const loraCount = size(loras);

  if (loraCount === 0) {
    return;
  }

  const loraMetadata: LoRAMetadataItem[] = [];
  const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
    | MetadataAccumulatorInvocation
    | undefined;

  // Handle Seamless Plugs
  const unetLoaderId = modelLoaderNodeId;
@@ -51,17 +47,22 @@ export const addSDXLLoRAsToGraph = (
    clipLoaderId = SDXL_MODEL_LOADER;
  }

  // Remove modelLoaderNodeId unet/clip/clip2 connections to feed it to LoRAs
  graph.edges = graph.edges.filter(
    (e) =>
      !(
        e.source.node_id === unetLoaderId && ['unet'].includes(e.source.field)
      ) &&
      !(
        e.source.node_id === clipLoaderId && ['clip'].includes(e.source.field)
      ) &&
      !(e.source.node_id === clipLoaderId && ['clip2'].includes(e.source.field))
  );
  if (loraCount > 0) {
    // Remove modelLoaderNodeId unet/clip/clip2 connections to feed it to LoRAs
    graph.edges = graph.edges.filter(
      (e) =>
        !(
          e.source.node_id === unetLoaderId && ['unet'].includes(e.source.field)
        ) &&
        !(
          e.source.node_id === clipLoaderId && ['clip'].includes(e.source.field)
        ) &&
        !(
          e.source.node_id === clipLoaderId &&
          ['clip2'].includes(e.source.field)
        )
    );
  }

  // we need to remember the last lora so we can chain from it
  let lastLoraNodeId = '';
@@ -79,12 +80,16 @@ export const addSDXLLoRAsToGraph = (
      weight,
    };

    loraMetadata.push(
      zLoRAMetadataItem.parse({
    // add the lora to the metadata accumulator
    if (metadataAccumulator) {
      if (!metadataAccumulator.loras) {
        metadataAccumulator.loras = [];
      }
      metadataAccumulator.loras.push({
        lora: { model_name, base_model },
        weight,
      })
    );
      });
    }

    // add to graph
    graph.nodes[currentLoraNodeId] = loraLoaderNode;
@@ -237,6 +242,4 @@ export const addSDXLLoRAsToGraph = (
    lastLoraNodeId = currentLoraNodeId;
    currentLoraIndex += 1;
  });

  addMainMetadata(graph, { loras: loraMetadata });
};

@@ -2,6 +2,7 @@ import { RootState } from 'app/store/store';
import {
  CreateDenoiseMaskInvocation,
  ImageDTO,
  MetadataAccumulatorInvocation,
  SeamlessModeInvocation,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
@@ -11,6 +12,7 @@ import {
  LATENTS_TO_IMAGE,
  MASK_COMBINE,
  MASK_RESIZE_UP,
  METADATA_ACCUMULATOR,
  SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
  SDXL_CANVAS_INPAINT_GRAPH,
  SDXL_CANVAS_OUTPAINT_GRAPH,
@@ -24,7 +26,6 @@ import {
  SDXL_REFINER_SEAMLESS,
} from './constants';
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
import { addMainMetadata } from './metadata';

export const addSDXLRefinerToGraph = (
  state: RootState,
@@ -56,15 +57,21 @@ export const addSDXLRefinerToGraph = (
    return;
  }

  addMainMetadata(graph, {
    refiner_model: refinerModel,
    refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
    refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
    refiner_cfg_scale: refinerCFGScale,
    refiner_scheduler: refinerScheduler,
    refiner_start: refinerStart,
    refiner_steps: refinerSteps,
  });
  const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
    | MetadataAccumulatorInvocation
    | undefined;

  if (metadataAccumulator) {
    metadataAccumulator.refiner_model = refinerModel;
    metadataAccumulator.refiner_positive_aesthetic_score =
      refinerPositiveAestheticScore;
    metadataAccumulator.refiner_negative_aesthetic_score =
      refinerNegativeAestheticScore;
    metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
    metadataAccumulator.refiner_scheduler = refinerScheduler;
    metadataAccumulator.refiner_start = refinerStart;
    metadataAccumulator.refiner_steps = refinerSteps;
  }

  const modelLoaderId = modelLoaderNodeId
    ? modelLoaderNodeId

@@ -1,14 +1,18 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { SaveImageInvocation } from 'services/api/types';
import {
  CANVAS_OUTPUT,
  LATENTS_TO_IMAGE,
  METADATA_ACCUMULATOR,
  NSFW_CHECKER,
  SAVE_IMAGE,
  WATERMARKER,
} from './constants';
import {
  MetadataAccumulatorInvocation,
  SaveImageInvocation,
} from 'services/api/types';
import { RootState } from 'app/store/store';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';

/**
 * Set the `use_cache` field on the linear/canvas graph's final image output node to False.
@@ -32,6 +36,23 @@ export const addSaveImageNode = (

  graph.nodes[SAVE_IMAGE] = saveImageNode;

  const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
    | MetadataAccumulatorInvocation
    | undefined;

  if (metadataAccumulator) {
    graph.edges.push({
      source: {
        node_id: METADATA_ACCUMULATOR,
        field: 'metadata',
      },
      destination: {
        node_id: SAVE_IMAGE,
        field: 'metadata',
      },
    });
  }

  const destination = {
    node_id: SAVE_IMAGE,
    field: 'image',

@@ -1,7 +1,6 @@
import { RootState } from 'app/store/store';
import { SeamlessModeInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import { addMainMetadata } from './metadata';
import {
  CANVAS_COHERENCE_DENOISE_LATENTS,
  CANVAS_INPAINT_GRAPH,
@@ -32,11 +31,6 @@ export const addSeamlessToLinearGraph = (
    seamless_y: seamlessYAxis,
  } as SeamlessModeInvocation;

  addMainMetadata(graph, {
    seamless_x: seamlessXAxis,
    seamless_y: seamlessYAxis,
  });

  let denoisingNodeId = DENOISE_LATENTS;

  if (

@@ -1,5 +1,6 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { MetadataAccumulatorInvocation } from 'services/api/types';
import {
  CANVAS_COHERENCE_INPAINT_CREATE_MASK,
  CANVAS_IMAGE_TO_IMAGE_GRAPH,
@@ -13,6 +14,7 @@ import {
  INPAINT_IMAGE,
  LATENTS_TO_IMAGE,
  MAIN_MODEL_LOADER,
  METADATA_ACCUMULATOR,
  ONNX_MODEL_LOADER,
  SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
  SDXL_CANVAS_INPAINT_GRAPH,
@@ -24,7 +26,6 @@ import {
  TEXT_TO_IMAGE_GRAPH,
  VAE_LOADER,
} from './constants';
import { addMainMetadata } from './metadata';

export const addVAEToGraph = (
  state: RootState,
@@ -40,6 +41,9 @@ export const addVAEToGraph = (
  );

  const isAutoVae = !vae;
  const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
    | MetadataAccumulatorInvocation
    | undefined;

  if (!isAutoVae) {
    graph.nodes[VAE_LOADER] = {
@@ -177,7 +181,7 @@ export const addVAEToGraph = (
    }
  }

  if (vae) {
    addMainMetadata(graph, { vae });
  if (vae && metadataAccumulator) {
    metadataAccumulator.vae = vae;
  }
};

@@ -5,8 +5,14 @@ import {
  ImageNSFWBlurInvocation,
  ImageWatermarkInvocation,
  LatentsToImageInvocation,
  MetadataAccumulatorInvocation,
} from 'services/api/types';
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants';
import {
  LATENTS_TO_IMAGE,
  METADATA_ACCUMULATOR,
  NSFW_CHECKER,
  WATERMARKER,
} from './constants';

export const addWatermarkerToGraph = (
  state: RootState,
@@ -26,6 +32,10 @@ export const addWatermarkerToGraph = (
    | ImageNSFWBlurInvocation
    | undefined;

  const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
    | MetadataAccumulatorInvocation
    | undefined;

  if (!nodeToAddTo) {
    // something has gone terribly awry
    return;
@@ -70,4 +80,17 @@ export const addWatermarkerToGraph = (
      },
    });
  }

  if (metadataAccumulator) {
    graph.edges.push({
      source: {
        node_id: METADATA_ACCUMULATOR,
        field: 'metadata',
      },
      destination: {
        node_id: WATERMARKER,
        field: 'metadata',
      },
    });
  }
};

@@ -1,13 +1,12 @@
import { BoardId } from 'features/gallery/store/types';
import { NonNullableGraph } from 'features/nodes/types/types';
import { ESRGANModelName } from 'features/parameters/store/postprocessingSlice';
import {
  ESRGANInvocation,
  Graph,
  ESRGANInvocation,
  SaveImageInvocation,
} from 'services/api/types';
import { REALESRGAN as ESRGAN, SAVE_IMAGE } from './constants';
import { addMainMetadataNodeToGraph } from './metadata';
import { BoardId } from 'features/gallery/store/types';

type Arg = {
  image_name: string;
@@ -56,9 +55,5 @@ export const buildAdHocUpscaleGraph = ({
    ],
  };

  addMainMetadataNodeToGraph(graph, {
    model: esrganModelName,
  });

  return graph;
};

@@ -19,12 +19,12 @@ import {
  IMG2IMG_RESIZE,
  LATENTS_TO_IMAGE,
  MAIN_MODEL_LOADER,
  METADATA_ACCUMULATOR,
  NEGATIVE_CONDITIONING,
  NOISE,
  POSITIVE_CONDITIONING,
  SEAMLESS,
} from './constants';
import { addMainMetadataNodeToGraph } from './metadata';

/**
 * Builds the Canvas tab's Image to Image graph.
@@ -307,7 +307,10 @@ export const buildCanvasImageToImageGraph = (
    });
  }

  addMainMetadataNodeToGraph(graph, {
  // add metadata accumulator, which is only mostly populated - some fields are added later
  graph.nodes[METADATA_ACCUMULATOR] = {
    id: METADATA_ACCUMULATOR,
    type: 'metadata_accumulator',
    generation_mode: 'img2img',
    cfg_scale,
    width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
@@ -321,10 +324,13 @@ export const buildCanvasImageToImageGraph = (
    steps,
    rand_device: use_cpu ? 'cpu' : 'cuda',
    scheduler,
    vae: undefined, // option; set in addVAEToGraph
    controlnets: [], // populated in addControlNetToLinearGraph
    loras: [], // populated in addLoRAsToGraph
    clip_skip: clipSkip,
    strength,
    init_image: initialImage.image_name,
  });
  };

  // Add Seamless To Graph
  if (seamlessXAxis || seamlessYAxis) {

@@ -16,6 +16,7 @@ import {
  IMAGE_TO_LATENTS,
  IMG2IMG_RESIZE,
  LATENTS_TO_IMAGE,
  METADATA_ACCUMULATOR,
  NEGATIVE_CONDITIONING,
  NOISE,
  POSITIVE_CONDITIONING,
@@ -26,7 +27,6 @@ import {
  SEAMLESS,
} from './constants';
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
import { addMainMetadataNodeToGraph } from './metadata';

/**
 * Builds the Canvas tab's Image to Image graph.
@@ -318,7 +318,10 @@ export const buildCanvasSDXLImageToImageGraph = (
    });
  }

  addMainMetadataNodeToGraph(graph, {
  // add metadata accumulator, which is only mostly populated - some fields are added later
  graph.nodes[METADATA_ACCUMULATOR] = {
    id: METADATA_ACCUMULATOR,
    type: 'metadata_accumulator',
    generation_mode: 'img2img',
    cfg_scale,
    width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
@@ -332,8 +335,22 @@ export const buildCanvasSDXLImageToImageGraph = (
    steps,
    rand_device: use_cpu ? 'cpu' : 'cuda',
    scheduler,
    vae: undefined, // option; set in addVAEToGraph
    controlnets: [], // populated in addControlNetToLinearGraph
    loras: [], // populated in addLoRAsToGraph
    strength,
    init_image: initialImage.image_name,
  };

  graph.edges.push({
    source: {
      node_id: METADATA_ACCUMULATOR,
      field: 'metadata',
    },
    destination: {
      node_id: CANVAS_OUTPUT,
      field: 'metadata',
    },
  });

  // Add Seamless To Graph

@@ -17,6 +17,7 @@ import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
  CANVAS_OUTPUT,
  LATENTS_TO_IMAGE,
  METADATA_ACCUMULATOR,
  NEGATIVE_CONDITIONING,
  NOISE,
  ONNX_MODEL_LOADER,
@@ -28,7 +29,6 @@ import {
  SEAMLESS,
} from './constants';
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
import { addMainMetadataNodeToGraph } from './metadata';

/**
 * Builds the Canvas tab's Text to Image graph.
@@ -300,7 +300,10 @@ export const buildCanvasSDXLTextToImageGraph = (
    });
  }

  addMainMetadataNodeToGraph(graph, {
  // add metadata accumulator, which is only mostly populated - some fields are added later
  graph.nodes[METADATA_ACCUMULATOR] = {
    id: METADATA_ACCUMULATOR,
    type: 'metadata_accumulator',
    generation_mode: 'txt2img',
    cfg_scale,
    width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
@@ -314,6 +317,20 @@ export const buildCanvasSDXLTextToImageGraph = (
    steps,
    rand_device: use_cpu ? 'cpu' : 'cuda',
    scheduler,
    vae: undefined, // option; set in addVAEToGraph
    controlnets: [], // populated in addControlNetToLinearGraph
    loras: [], // populated in addLoRAsToGraph
  };

  graph.edges.push({
    source: {
      node_id: METADATA_ACCUMULATOR,
      field: 'metadata',
    },
    destination: {
      node_id: CANVAS_OUTPUT,
      field: 'metadata',
    },
  });

  // Add Seamless To Graph

@@ -20,13 +20,13 @@ import {
  DENOISE_LATENTS,
  LATENTS_TO_IMAGE,
  MAIN_MODEL_LOADER,
  METADATA_ACCUMULATOR,
  NEGATIVE_CONDITIONING,
  NOISE,
  ONNX_MODEL_LOADER,
  POSITIVE_CONDITIONING,
  SEAMLESS,
} from './constants';
import { addMainMetadataNodeToGraph } from './metadata';

/**
 * Builds the Canvas tab's Text to Image graph.
@@ -288,7 +288,10 @@ export const buildCanvasTextToImageGraph = (
    });
  }

  addMainMetadataNodeToGraph(graph, {
  // add metadata accumulator, which is only mostly populated - some fields are added later
  graph.nodes[METADATA_ACCUMULATOR] = {
    id: METADATA_ACCUMULATOR,
    type: 'metadata_accumulator',
    generation_mode: 'txt2img',
    cfg_scale,
    width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
@@ -302,7 +305,21 @@ export const buildCanvasTextToImageGraph = (
    steps,
    rand_device: use_cpu ? 'cpu' : 'cuda',
    scheduler,
    vae: undefined, // option; set in addVAEToGraph
    controlnets: [], // populated in addControlNetToLinearGraph
    loras: [], // populated in addLoRAsToGraph
    clip_skip: clipSkip,
  };

  graph.edges.push({
    source: {
      node_id: METADATA_ACCUMULATOR,
      field: 'metadata',
    },
    destination: {
      node_id: CANVAS_OUTPUT,
      field: 'metadata',
    },
  });

  // Add Seamless To Graph

@@ -4,20 +4,13 @@ import { generateSeeds } from 'common/util/generateSeeds';
import { NonNullableGraph } from 'features/nodes/types/types';
import { range, unset } from 'lodash-es';
import { components } from 'services/api/schema';
import { Batch, BatchConfig, MetadataItemInvocation } from 'services/api/types';
import { Batch, BatchConfig } from 'services/api/types';
import {
  BATCH_PROMPT,
  BATCH_SEED,
  BATCH_STYLE_PROMPT,
  CANVAS_COHERENCE_NOISE,
  METADATA_ACCUMULATOR,
  NOISE,
  POSITIVE_CONDITIONING,
} from './constants';
import {
  addBatchMetadataNodeToGraph,
  removeMetadataFromMainMetadataNode,
} from './metadata';

export const prepareLinearUIBatch = (
  state: RootState,
@@ -30,27 +23,8 @@ export const prepareLinearUIBatch = (

  const data: Batch['data'] = [];

  const seedMetadataItemNode: MetadataItemInvocation = {
    id: BATCH_SEED,
    type: 'metadata_item',
    label: 'seed',
  };

  const promptMetadataItemNode: MetadataItemInvocation = {
    id: BATCH_PROMPT,
    type: 'metadata_item',
    label: 'positive_prompt',
  };

  const stylePromptMetadataItemNode: MetadataItemInvocation = {
    id: BATCH_STYLE_PROMPT,
    type: 'metadata_item',
    label: 'positive_style_prompt',
  };

  const itemNodesIds: string[] = [];

  if (prompts.length === 1) {
    unset(graph.nodes[METADATA_ACCUMULATOR], 'seed');
    const seeds = generateSeeds({
      count: iterations,
      start: shouldRandomizeSeed ? undefined : seed,
@@ -66,15 +40,13 @@ export const prepareLinearUIBatch = (
      });
    }

    // add to metadata
    removeMetadataFromMainMetadataNode(graph, 'seed');
    itemNodesIds.push(BATCH_SEED);
    graph.nodes[BATCH_SEED] = seedMetadataItemNode;
    zipped.push({
      node_path: BATCH_SEED,
      field_name: 'value',
      items: seeds,
    });
    if (graph.nodes[METADATA_ACCUMULATOR]) {
      zipped.push({
        node_path: METADATA_ACCUMULATOR,
        field_name: 'seed',
        items: seeds,
      });
    }

    if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
      zipped.push({
@@ -105,15 +77,13 @@ export const prepareLinearUIBatch = (
      });
    }

    // add to metadata
    removeMetadataFromMainMetadataNode(graph, 'seed');
    itemNodesIds.push(BATCH_SEED);
    graph.nodes[BATCH_SEED] = seedMetadataItemNode;
    firstBatchDatumList.push({
      node_path: BATCH_SEED,
      field_name: 'value',
      items: seeds,
    });
    if (graph.nodes[METADATA_ACCUMULATOR]) {
      firstBatchDatumList.push({
        node_path: METADATA_ACCUMULATOR,
        field_name: 'seed',
        items: seeds,
      });
    }

    if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
      firstBatchDatumList.push({
@@ -136,17 +106,13 @@ export const prepareLinearUIBatch = (
        items: seeds,
      });
    }

    // add to metadata
    removeMetadataFromMainMetadataNode(graph, 'seed');
    itemNodesIds.push(BATCH_SEED);
    graph.nodes[BATCH_SEED] = seedMetadataItemNode;
    secondBatchDatumList.push({
      node_path: BATCH_SEED,
      field_name: 'value',
      items: seeds,
    });

    if (graph.nodes[METADATA_ACCUMULATOR]) {
      secondBatchDatumList.push({
        node_path: METADATA_ACCUMULATOR,
        field_name: 'seed',
        items: seeds,
      });
    }
    if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
      secondBatchDatumList.push({
        node_path: CANVAS_COHERENCE_NOISE,
@@ -171,15 +137,13 @@ export const prepareLinearUIBatch = (
      });
    }

    // add to metadata
    removeMetadataFromMainMetadataNode(graph, 'positive_prompt');
    itemNodesIds.push(BATCH_PROMPT);
    graph.nodes[BATCH_PROMPT] = promptMetadataItemNode;
    firstBatchDatumList.push({
      node_path: BATCH_PROMPT,
      field_name: 'value',
      items: extendedPrompts,
    });
    if (graph.nodes[METADATA_ACCUMULATOR]) {
      firstBatchDatumList.push({
        node_path: METADATA_ACCUMULATOR,
        field_name: 'positive_prompt',
        items: extendedPrompts,
      });
    }

    if (shouldConcatSDXLStylePrompt && model?.base_model === 'sdxl') {
      unset(graph.nodes[METADATA_ACCUMULATOR], 'positive_style_prompt');
@@ -196,22 +160,18 @@ export const prepareLinearUIBatch = (
      });
    }

      // add to metadata
      removeMetadataFromMainMetadataNode(graph, 'positive_style_prompt');
      itemNodesIds.push(BATCH_STYLE_PROMPT);
      graph.nodes[BATCH_STYLE_PROMPT] = stylePromptMetadataItemNode;
      firstBatchDatumList.push({
        node_path: BATCH_STYLE_PROMPT,
        field_name: 'value',
        items: extendedPrompts,
      });
      if (graph.nodes[METADATA_ACCUMULATOR]) {
        firstBatchDatumList.push({
          node_path: METADATA_ACCUMULATOR,
          field_name: 'positive_style_prompt',
          items: stylePrompts,
        });
      }
    }

    data.push(firstBatchDatumList);
  }

  addBatchMetadataNodeToGraph(graph, itemNodesIds);

  const enqueueBatchArg: BatchConfig = {
    prepend,
    batch: {

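The batch hunks above collapse the per-item `metadata_item` nodes back into direct writes against the accumulator: each batch datum now targets `METADATA_ACCUMULATOR` fields such as `seed` and `positive_prompt`, guarded by a check that the node exists. A rough sketch of the shape being assembled, with simplified types standing in for the generated `Batch`/`BatchConfig` API types:

```typescript
// Simplified stand-ins for the generated Batch/BatchConfig API types.
type BatchDatum = { node_path: string; field_name: string; items: unknown[] };
type BatchConfig = {
  prepend: boolean;
  batch: { data: BatchDatum[][] };
};

const METADATA_ACCUMULATOR = 'metadata_accumulator';
const NOISE = 'noise';

// Each inner list is "zipped": the queue walks the item arrays in lockstep,
// so run i uses seeds[i] for both the noise node and the metadata record.
const buildSeedBatch = (seeds: number[], prepend = false): BatchConfig => {
  const zipped: BatchDatum[] = [
    { node_path: NOISE, field_name: 'seed', items: seeds },
    // Mirror the same values into the metadata accumulator so the saved
    // image records the seed that actually produced it.
    { node_path: METADATA_ACCUMULATOR, field_name: 'seed', items: seeds },
  ];
  return { prepend, batch: { data: [zipped] } };
};
```
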
@@ -20,13 +20,13 @@ import {
  IMAGE_TO_LATENTS,
  LATENTS_TO_IMAGE,
  MAIN_MODEL_LOADER,
  METADATA_ACCUMULATOR,
  NEGATIVE_CONDITIONING,
  NOISE,
  POSITIVE_CONDITIONING,
  RESIZE,
  SEAMLESS,
} from './constants';
import { addMainMetadataNodeToGraph } from './metadata';

/**
 * Builds the Image to Image tab graph.
@@ -310,7 +310,10 @@ export const buildLinearImageToImageGraph = (
    });
  }

  addMainMetadataNodeToGraph(graph, {
  // add metadata accumulator, which is only mostly populated - some fields are added later
  graph.nodes[METADATA_ACCUMULATOR] = {
    id: METADATA_ACCUMULATOR,
    type: 'metadata_accumulator',
    generation_mode: 'img2img',
    cfg_scale,
    height,
@@ -322,9 +325,23 @@ export const buildLinearImageToImageGraph = (
    steps,
    rand_device: use_cpu ? 'cpu' : 'cuda',
    scheduler,
    vae: undefined, // option; set in addVAEToGraph
    controlnets: [], // populated in addControlNetToLinearGraph
    loras: [], // populated in addLoRAsToGraph
    clip_skip: clipSkip,
    strength,
    init_image: initialImage.imageName,
  };

  graph.edges.push({
    source: {
      node_id: METADATA_ACCUMULATOR,
      field: 'metadata',
    },
    destination: {
      node_id: LATENTS_TO_IMAGE,
      field: 'metadata',
    },
  });

  // Add Seamless To Graph

@@ -17,6 +17,7 @@ import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
  IMAGE_TO_LATENTS,
  LATENTS_TO_IMAGE,
  METADATA_ACCUMULATOR,
  NEGATIVE_CONDITIONING,
  NOISE,
  POSITIVE_CONDITIONING,
@@ -28,7 +29,6 @@ import {
  SEAMLESS,
} from './constants';
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
import { addMainMetadataNodeToGraph } from './metadata';

/**
 * Builds the Image to Image tab graph.
@@ -330,7 +330,10 @@ export const buildLinearSDXLImageToImageGraph = (
    });
  }

  addMainMetadataNodeToGraph(graph, {
  // add metadata accumulator, which is only mostly populated - some fields are added later
  graph.nodes[METADATA_ACCUMULATOR] = {
    id: METADATA_ACCUMULATOR,
    type: 'metadata_accumulator',
    generation_mode: 'sdxl_img2img',
    cfg_scale,
    height,
@@ -342,10 +345,24 @@ export const buildLinearSDXLImageToImageGraph = (
    steps,
    rand_device: use_cpu ? 'cpu' : 'cuda',
    scheduler,
    strength,
    vae: undefined,
    controlnets: [],
    loras: [],
    strength: strength,
    init_image: initialImage.imageName,
    positive_style_prompt: positiveStylePrompt,
    negative_style_prompt: negativeStylePrompt,
  };

  graph.edges.push({
    source: {
      node_id: METADATA_ACCUMULATOR,
      field: 'metadata',
    },
    destination: {
      node_id: LATENTS_TO_IMAGE,
      field: 'metadata',
    },
  });

  // Add Seamless To Graph

@@ -10,9 +10,9 @@ import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { addMainMetadataNodeToGraph } from './metadata';
import {
  LATENTS_TO_IMAGE,
  METADATA_ACCUMULATOR,
  NEGATIVE_CONDITIONING,
  NOISE,
  POSITIVE_CONDITIONING,
@@ -224,7 +224,10 @@ export const buildLinearSDXLTextToImageGraph = (
    ],
  };

  addMainMetadataNodeToGraph(graph, {
  // add metadata accumulator, which is only mostly populated - some fields are added later
  graph.nodes[METADATA_ACCUMULATOR] = {
    id: METADATA_ACCUMULATOR,
    type: 'metadata_accumulator',
    generation_mode: 'sdxl_txt2img',
    cfg_scale,
    height,
@@ -236,8 +239,22 @@ export const buildLinearSDXLTextToImageGraph = (
    steps,
    rand_device: use_cpu ? 'cpu' : 'cuda',
    scheduler,
    vae: undefined,
    controlnets: [],
    loras: [],
    positive_style_prompt: positiveStylePrompt,
    negative_style_prompt: negativeStylePrompt,
  };

  graph.edges.push({
    source: {
      node_id: METADATA_ACCUMULATOR,
      field: 'metadata',
    },
    destination: {
      node_id: LATENTS_TO_IMAGE,
      field: 'metadata',
    },
  });

  // Add Seamless To Graph

@@ -13,12 +13,12 @@ import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { addMainMetadataNodeToGraph } from './metadata';
import {
  CLIP_SKIP,
  DENOISE_LATENTS,
  LATENTS_TO_IMAGE,
  MAIN_MODEL_LOADER,
  METADATA_ACCUMULATOR,
  NEGATIVE_CONDITIONING,
  NOISE,
  ONNX_MODEL_LOADER,
@@ -232,7 +232,10 @@ export const buildLinearTextToImageGraph = (
    ],
  };

  addMainMetadataNodeToGraph(graph, {
  // add metadata accumulator, which is only mostly populated - some fields are added later
  graph.nodes[METADATA_ACCUMULATOR] = {
    id: METADATA_ACCUMULATOR,
    type: 'metadata_accumulator',
    generation_mode: 'txt2img',
    cfg_scale,
    height,
@@ -244,7 +247,21 @@ export const buildLinearTextToImageGraph = (
    steps,
    rand_device: use_cpu ? 'cpu' : 'cuda',
    scheduler,
    vae: undefined, // option; set in addVAEToGraph
    controlnets: [], // populated in addControlNetToLinearGraph
    loras: [], // populated in addLoRAsToGraph
    clip_skip: clipSkip,
  };

  graph.edges.push({
    source: {
      node_id: METADATA_ACCUMULATOR,
      field: 'metadata',
    },
    destination: {
      node_id: LATENTS_TO_IMAGE,
      field: 'metadata',
    },
  });

  // Add Seamless To Graph

@@ -50,15 +50,7 @@ export const IP_ADAPTER = 'ip_adapter';
export const DYNAMIC_PROMPT = 'dynamic_prompt';
export const IMAGE_COLLECTION = 'image_collection';
export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate';
export const METADATA = 'metadata';
export const BATCH_METADATA = 'batch_metadata';
export const BATCH_METADATA_COLLECT = 'batch_metadata_collect';
export const BATCH_SEED = 'batch_seed';
export const BATCH_PROMPT = 'batch_prompt';
export const BATCH_STYLE_PROMPT = 'batch_style_prompt';
export const METADATA_COLLECT = 'metadata_collect';
export const METADATA_ACCUMULATOR = 'metadata_accumulator';
export const MERGE_METADATA = 'merge_metadata';
export const REALESRGAN = 'esrgan';
export const DIVIDE = 'divide';
export const SCALE = 'scale_image';

@@ -1,151 +0,0 @@
import { NonNullableGraph } from 'features/nodes/types/types';
import { map } from 'lodash-es';
import { MetadataInvocationAsCollection } from 'services/api/types';
import { JsonObject } from 'type-fest';
import {
  BATCH_METADATA,
  BATCH_METADATA_COLLECT,
  MERGE_METADATA,
  METADATA,
  METADATA_COLLECT,
  SAVE_IMAGE,
} from './constants';

export const addMainMetadataNodeToGraph = (
  graph: NonNullableGraph,
  metadata: JsonObject
): void => {
  graph.nodes[METADATA] = {
    id: METADATA,
    type: 'metadata',
    items: map(metadata, (value, label) => ({ label, value })),
  };

  graph.nodes[METADATA_COLLECT] = {
    id: METADATA_COLLECT,
    type: 'collect',
  };

  graph.nodes[MERGE_METADATA] = {
    id: MERGE_METADATA,
    type: 'merge_metadata_dict',
  };

  graph.edges.push({
    source: {
      node_id: METADATA,
      field: 'metadata_dict',
    },
    destination: {
      node_id: METADATA_COLLECT,
      field: 'item',
    },
  });

  graph.edges.push({
    source: {
      node_id: METADATA_COLLECT,
      field: 'collection',
    },
    destination: {
      node_id: MERGE_METADATA,
      field: 'collection',
    },
  });

  graph.edges.push({
    source: {
      node_id: MERGE_METADATA,
      field: 'metadata_dict',
    },
    destination: {
      node_id: SAVE_IMAGE,
      field: 'metadata',
    },
  });

  return;
};

export const addMainMetadata = (
  graph: NonNullableGraph,
  metadata: JsonObject
): void => {
  const metadataNode = graph.nodes[METADATA] as
    | MetadataInvocationAsCollection
    | undefined;

  if (!metadataNode) {
    return;
  }

  metadataNode.items.push(
    ...map(metadata, (value, label) => ({ label, value }))
  );
};

export const removeMetadataFromMainMetadataNode = (
  graph: NonNullableGraph,
  label: string
): void => {
  const metadataNode = graph.nodes[METADATA] as
    | MetadataInvocationAsCollection
    | undefined;

  if (!metadataNode) {
    return;
  }

  metadataNode.items = metadataNode.items.filter(
    (item) => item.label !== label
  );
};

export const addBatchMetadataNodeToGraph = (
  graph: NonNullableGraph,
  itemNodeIds: string[]
) => {
  graph.nodes[BATCH_METADATA] = {
    id: BATCH_METADATA,
    type: 'metadata',
  };
  graph.nodes[BATCH_METADATA_COLLECT] = {
    id: BATCH_METADATA_COLLECT,
    type: 'collect',
  };

  itemNodeIds.forEach((id) => {
    graph.edges.push({
      source: {
        node_id: id,
        field: 'item',
      },
      destination: {
        node_id: BATCH_METADATA_COLLECT,
        field: 'item',
      },
    });
  });

  graph.edges.push({
    source: {
      node_id: BATCH_METADATA_COLLECT,
      field: 'collection',
    },
    destination: {
      node_id: BATCH_METADATA,
      field: 'items',
    },
  });

  graph.edges.push({
    source: {
      node_id: BATCH_METADATA,
      field: 'metadata_dict',
    },
    destination: {
      node_id: METADATA_COLLECT,
      field: 'item',
    },
  });
};

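The whole `metadata.ts` helper module above is deleted by this compare. For reference, its flow was: a `metadata` node holding label/value items, a `collect` node gathering per-image dicts, and a `merge_metadata_dict` node feeding the final save-image input. A condensed sketch of that pipeline, using a simplified graph shape rather than the project's `NonNullableGraph`:

```typescript
// Condensed view of the removed helper's wiring, with a simplified graph
// shape standing in for NonNullableGraph. Edges form a small pipeline:
//   metadata -> collect -> merge_metadata_dict -> save_image
type Edge = {
  source: { node_id: string; field: string };
  destination: { node_id: string; field: string };
};
type Graph = { nodes: Record<string, object>; edges: Edge[] };

const wireMetadataPipeline = (graph: Graph, saveImageId: string): void => {
  const chain: Array<[string, string, string, string]> = [
    // [source node, source field, destination node, destination field]
    ['metadata', 'metadata_dict', 'metadata_collect', 'item'],
    ['metadata_collect', 'collection', 'merge_metadata', 'collection'],
    ['merge_metadata', 'metadata_dict', saveImageId, 'metadata'],
  ];
  for (const [srcNode, srcField, dstNode, dstField] of chain) {
    graph.edges.push({
      source: { node_id: srcNode, field: srcField },
      destination: { node_id: dstNode, field: dstField },
    });
  }
};
```
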
@@ -4,6 +4,7 @@ import { reduce } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import { AnyInvocationType } from 'services/events/types';
import {
  FieldType,
  InputFieldTemplate,
  InvocationSchemaObject,
  InvocationTemplate,
@@ -15,11 +16,18 @@ import {
} from '../types/types';
import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';

const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache'];
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'metadata', 'use_cache'];
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
const RESERVED_FIELD_TYPES = ['IsIntermediate', 'WorkflowField'];
const RESERVED_FIELD_TYPES = [
  'WorkflowField',
  'MetadataField',
  'IsIntermediate',
];

const invocationDenylist: AnyInvocationType[] = ['graph'];
const invocationDenylist: AnyInvocationType[] = [
  'graph',
  'metadata_accumulator',
];

const isReservedInputField = (nodeType: string, fieldName: string) => {
  if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
@@ -34,7 +42,7 @@ const isReservedInputField = (nodeType: string, fieldName: string) => {
  return false;
};

const isReservedFieldType = (fieldType: string) => {
const isReservedFieldType = (fieldType: FieldType) => {
  if (RESERVED_FIELD_TYPES.includes(fieldType)) {
    return true;
  }
@@ -78,7 +86,6 @@ export const parseSchema = (
  const tags = schema.tags ?? [];
  const description = schema.description ?? '';
  const version = schema.version;
  let withWorkflow = false;

  const inputs = reduce(
    schema.properties,
@@ -105,7 +112,7 @@ export const parseSchema = (

      const fieldType = getFieldType(property);

      if (!fieldType) {
      if (!isFieldType(fieldType)) {
        logger('nodes').warn(
          {
            node: type,
@@ -113,16 +120,11 @@ export const parseSchema = (
            fieldType,
            field: parseify(property),
          },
          'Missing input field type'
          'Skipping unknown input field type'
        );
        return inputsAccumulator;
      }

      if (fieldType === 'WorkflowField') {
        withWorkflow = true;
        return inputsAccumulator;
      }

      if (isReservedFieldType(fieldType)) {
        logger('nodes').trace(
          {
@@ -131,20 +133,7 @@ export const parseSchema = (
            fieldType,
            field: parseify(property),
          },
          `Skipping reserved input field type: ${fieldType}`
        );
        return inputsAccumulator;
      }

      if (!isFieldType(fieldType)) {
        logger('nodes').warn(
          {
            node: type,
            fieldName: propertyName,
            fieldType,
            field: parseify(property),
          },
          `Skipping unknown input field type: ${fieldType}`
          'Skipping reserved field type'
        );
        return inputsAccumulator;
      }
@@ -157,7 +146,7 @@ export const parseSchema = (
      );

      if (!field) {
        logger('nodes').warn(
        logger('nodes').debug(
          {
            node: type,
            fieldName: propertyName,
@@ -258,7 +247,6 @@ export const parseSchema = (
      inputs,
      outputs,
      useCache,
      withWorkflow,
    };

    Object.assign(invocationsAccumulator, { [type]: invocation });

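These `parseSchema` hunks reorder the input-field checks: unknown field types are rejected first, then reserved types (now including `MetadataField`) are skipped, while the `WorkflowField`/`withWorkflow` special case is dropped from one side of the diff. A minimal sketch of that filtering order, with illustrative field and type names rather than the project's real template builders:

```typescript
// Minimal sketch of the check ordering; the real parser also builds input
// field templates and logs skipped fields through the app logger.
// All names here are illustrative, not the project's actual lists.
const RESERVED_FIELD_TYPES = ['WorkflowField', 'MetadataField', 'IsIntermediate'];
const KNOWN_FIELD_TYPES = ['integer', 'float', 'string', ...RESERVED_FIELD_TYPES];

const filterInputFields = (
  fields: Array<{ name: string; type: string }>
): string[] => {
  const accepted: string[] = [];
  for (const field of fields) {
    if (!KNOWN_FIELD_TYPES.includes(field.type)) {
      continue; // unknown type: skipped with a warning in the real code
    }
    if (RESERVED_FIELD_TYPES.includes(field.type)) {
      continue; // reserved types never become user-facing node inputs
    }
    accepted.push(field.name);
  }
  return accepted;
};
```
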
@@ -0,0 +1,41 @@
import { Flex, Skeleton } from '@chakra-ui/react';
import { memo } from 'react';
import { COLUMN_WIDTHS } from './constants';

const QueueItemSkeleton = () => {
  return (
    <Flex alignItems="center" p={1.5} gap={4} minH={9} h="full" w="full">
      <Flex
        w={COLUMN_WIDTHS.number}
        justifyContent="flex-end"
        alignItems="center"
      >
        <Skeleton w="full" h="full">
        </Skeleton>
      </Flex>
      <Flex w={COLUMN_WIDTHS.statusBadge} alignItems="center">
        <Skeleton w="full" h="full">
        </Skeleton>
      </Flex>
      <Flex w={COLUMN_WIDTHS.time} alignItems="center">
        <Skeleton w="full" h="full">
        </Skeleton>
      </Flex>
      <Flex w={COLUMN_WIDTHS.batchId} alignItems="center">
        <Skeleton w="full" h="full">
        </Skeleton>
      </Flex>
      <Flex w={COLUMN_WIDTHS.fieldValues} alignItems="center" flexGrow={1}>
        <Skeleton w="full" h="full">
        </Skeleton>
      </Flex>
    </Flex>
  );
};

export default memo(QueueItemSkeleton);

@@ -3,6 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
import {
  listCursorChanged,
  listPriorityChanged,
@@ -85,7 +86,7 @@ const QueueList = () => {
    return () => osInstance()?.destroy();
  }, [scroller, initialize, osInstance]);

  const { data: listQueueItemsData } = useListQueueItemsQuery({
  const { data: listQueueItemsData, isLoading } = useListQueueItemsQuery({
    cursor: listCursor,
    priority: listPriority,
  });
@@ -125,36 +126,40 @@ const QueueList = () => {
    [openQueueItems, toggleQueueItem]
  );

  if (isLoading) {
    return <IAINoContentFallbackWithSpinner />;
  }

  if (!queueItems.length) {
    return (
      <Flex w="full" h="full" alignItems="center" justifyContent="center">
        <Heading color="base.400" _dark={{ color: 'base.500' }}>
          {t('queue.queueEmpty')}
        </Heading>
      </Flex>
    );
  }

  return (
    <Flex w="full" h="full" flexDir="column">
      {queueItems.length ? (
        <>
          <QueueListHeader />
          <Flex
            ref={rootRef}
            w="full"
            h="full"
            alignItems="center"
            justifyContent="center"
          >
            <Virtuoso<SessionQueueItemDTO, ListContext>
              data={queueItems}
              endReached={handleLoadMore}
              scrollerRef={setScroller as TableVirtuosoScrollerRef}
              itemContent={itemContent}
              computeItemKey={computeItemKey}
              components={components}
              context={context}
            />
          </Flex>
        </>
      ) : (
        <Flex w="full" h="full" alignItems="center" justifyContent="center">
          <Heading color="base.400" _dark={{ color: 'base.500' }}>
            {t('queue.queueEmpty')}
          </Heading>
        </Flex>
      )}
      <QueueListHeader />
      <Flex
        ref={rootRef}
        w="full"
        h="full"
        alignItems="center"
        justifyContent="center"
      >
        <Virtuoso<SessionQueueItemDTO, ListContext>
          data={queueItems}
          endReached={handleLoadMore}
          scrollerRef={setScroller as TableVirtuosoScrollerRef}
          itemContent={itemContent}
          computeItemKey={computeItemKey}
          components={components}
          context={context}
        />
      </Flex>
    </Flex>
  );
};

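With `isLoading` pulled from the RTK Query hook, the component can return early for the loading and empty states instead of branching inside the JSX. A stripped-down sketch of that pattern (the component names here are illustrative stand-ins for `IAINoContentFallbackWithSpinner`, `Heading`, and the Virtuoso markup):

```typescript
import React from 'react';

// Illustrative placeholders for the real fallback components.
const Spinner = () => <div>Loading…</div>;
const EmptyState = () => <div>Queue is empty</div>;

type Props<T> = {
  isLoading: boolean;
  items: T[];
  renderList: (items: T[]) => React.ReactElement;
};

// Early returns keep the happy-path JSX free of nested ternaries: first
// load shows the spinner, a loaded-but-empty queue shows the empty state,
// and only the final return renders the virtualized list.
function QueueListBody<T>({ isLoading, items, renderList }: Props<T>) {
  if (isLoading) {
    return <Spinner />;
  }
  if (!items.length) {
    return <EmptyState />;
  }
  return renderList(items);
}
```
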
@@ -10,7 +10,6 @@ import {
import {
  ImageMetadataAndWorkflow,
  zCoreMetadata,
  zWorkflow,
} from 'features/nodes/types/types';
import { getMetadataAndWorkflowFromImageBlob } from 'features/nodes/util/getMetadataAndWorkflowFromImageBlob';
import { keyBy } from 'lodash-es';
@@ -24,6 +23,7 @@ import {
  ListImagesArgs,
  OffsetPaginatedResults_ImageDTO_,
  PostUploadAction,
  UnsafeImageMetadata,
} from '../types';
import {
  getCategories,
@@ -33,7 +33,6 @@ import {
  imagesSelectors,
} from '../util';
import { boardsApi } from './boards';
import { logger } from 'app/logging/logger';

export const imagesApi = api.injectEndpoints({
  endpoints: (build) => ({
@@ -114,33 +113,11 @@ export const imagesApi = api.injectEndpoints({
      ],
      keepUnusedDataFor: 86400, // 24 hours
    }),
    getImageMetadata: build.query<ImageMetadataAndWorkflow, string>({
    getImageMetadata: build.query<UnsafeImageMetadata, string>({
      query: (image_name) => ({ url: `images/i/${image_name}/metadata` }),
      providesTags: (result, error, image_name) => [
        { type: 'ImageMetadata', id: image_name },
      ],
      transformResponse: (
        response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json']
      ) => {
        const imageMetadataAndWorkflow: ImageMetadataAndWorkflow = {};
        if (response?.metadata) {
          const metadataResult = zCoreMetadata.safeParse(response.metadata);
          if (metadataResult.success) {
            imageMetadataAndWorkflow.metadata = metadataResult.data;
          } else {
            logger('images').warn('Problem parsing metadata');
          }
        }
        if (response?.workflow) {
          const workflowResult = zWorkflow.safeParse(response.workflow);
          if (workflowResult.success) {
            imageMetadataAndWorkflow.workflow = workflowResult.data;
          } else {
            logger('images').warn('Problem parsing workflow');
          }
        }
        return imageMetadataAndWorkflow;
      },
      keepUnusedDataFor: 86400, // 24 hours
    }),
    getImageMetadataFromFile: build.query<

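This endpoint change drops the client-side zod validation: `getImageMetadata` now returns the server payload as `UnsafeImageMetadata` instead of safe-parsing it into `ImageMetadataAndWorkflow`. For contrast, a minimal sketch of the removed `safeParse` pattern, using an illustrative schema rather than the project's `zCoreMetadata` and `zWorkflow`:

```typescript
import { z } from 'zod';

// Illustrative schema; the removed code used zCoreMetadata and zWorkflow.
const zMetadata = z.object({
  seed: z.number().optional(),
  positive_prompt: z.string().optional(),
});
type Metadata = z.infer<typeof zMetadata>;

// safeParse never throws: it returns a tagged result, so a malformed server
// payload degrades to `undefined` plus a log line instead of a crash.
const parseMetadata = (raw: unknown): Metadata | undefined => {
  const result = zMetadata.safeParse(raw);
  if (!result.success) {
    console.warn('Problem parsing metadata', result.error.issues);
    return undefined;
  }
  return result.data;
};
```
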
@@ -4,7 +4,7 @@ import {
  ThunkDispatch,
  createEntityAdapter,
} from '@reduxjs/toolkit';
import { $queueId } from 'features/queue/store/nanoStores';
import { $queueId } from 'features/queue/store/queueNanoStore';
import { listParamsReset } from 'features/queue/store/queueSlice';
import queryString from 'query-string';
import { ApiTagDescription, api } from '..';

invokeai/frontend/web/src/services/api/schema.d.ts (vendored, 1540 lines changed): file diff suppressed because one or more lines are too long.

@@ -1,5 +1,5 @@
import { createAsyncThunk, isAnyOf } from '@reduxjs/toolkit';
import { $queueId } from 'features/queue/store/nanoStores';
import { $queueId } from 'features/queue/store/queueNanoStore';
import { isObject } from 'lodash-es';
import { $client } from 'services/api/client';
import { paths } from 'services/api/schema';

@@ -147,15 +147,6 @@ export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation'];
export type ImageWatermarkInvocation = s['ImageWatermarkInvocation'];
export type SeamlessModeInvocation = s['SeamlessModeInvocation'];
export type SaveImageInvocation = s['SaveImageInvocation'];
export type MetadataInvocation = s['MetadataInvocation'];
export type MetadataInvocationAsCollection = Omit<
  s['MetadataInvocation'],
  'items'
> & {
  items: s['MetadataItem'][];
};
export type MetadataItemInvocation = s['MetadataItemInvocation'];
export type MergeMetadataDictInvocation = s['MergeMetadataDictInvocation'];

// ControlNet Nodes
export type ControlNetInvocation = s['ControlNetInvocation'];

@@ -1,7 +1,7 @@
import { MiddlewareAPI } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { AppDispatch, RootState } from 'app/store/store';
import { $queueId } from 'features/queue/store/nanoStores';
import { $queueId } from 'features/queue/store/queueNanoStore';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { Socket } from 'socket.io-client';

@@ -72,7 +72,7 @@ dependencies = [
  "realesrgan",
  "requests~=2.28.2",
  "rich~=13.3",
  "safetensors==0.3.1",
  "safetensors~=0.3.1",
  "scikit-image~=0.21.0",
  "semver~=3.0.1",
  "send2trash",

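The dependency hunk relaxes `safetensors` from an exact pin (`==0.3.1`) to a compatible-release specifier (`~=0.3.1`), which under PEP 440 allows any 0.3.x release at or above 0.3.1 while still excluding 0.4, so compatible patch releases can be picked up without editing the pin.
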
@@ -9,12 +9,7 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.primitives import (
    FloatCollectionInvocation,
    FloatInvocation,
    IntegerInvocation,
    StringInvocation,
)
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.services.default_graphs import create_text_to_image
from invokeai.app.services.graph import (
@@ -31,11 +26,8 @@ from invokeai.app.services.graph import (
)

from .test_nodes import (
    AnyTypeTestInvocation,
    ImageToImageTestInvocation,
    ListPassThroughInvocation,
    PolymorphicStringTestInvocation,
    PromptCollectionTestInvocation,
    PromptTestInvocation,
    TextToImageTestInvocation,
)
@@ -699,146 +691,6 @@ def test_ints_do_not_accept_floats():
    g.add_edge(e)


def test_polymorphic_accepts_single():
    g = Graph()
    n1 = StringInvocation(id="1", value="banana")
    n2 = PolymorphicStringTestInvocation(id="2")
    g.add_node(n1)
    g.add_node(n2)
    e1 = create_edge(n1.id, "value", n2.id, "value")
    # Not throwing on this line is sufficient
    g.add_edge(e1)


def test_polymorphic_accepts_collection_of_same_base_type():
    g = Graph()
    n1 = PromptCollectionTestInvocation(id="1", collection=["banana", "sundae"])
    n2 = PolymorphicStringTestInvocation(id="2")
    g.add_node(n1)
    g.add_node(n2)
    e1 = create_edge(n1.id, "collection", n2.id, "value")
    # Not throwing on this line is sufficient
    g.add_edge(e1)


def test_polymorphic_does_not_accept_collection_of_different_base_type():
    g = Graph()
    n1 = FloatCollectionInvocation(id="1", collection=[1.0, 2.0, 3.0])
    n2 = PolymorphicStringTestInvocation(id="2")
    g.add_node(n1)
    g.add_node(n2)
    e1 = create_edge(n1.id, "collection", n2.id, "value")
    with pytest.raises(InvalidEdgeError):
        g.add_edge(e1)


def test_polymorphic_does_not_accept_generic_collection():
    g = Graph()
    n1 = IntegerInvocation(id="1", value=1)
    n2 = IntegerInvocation(id="2", value=2)
    n3 = CollectInvocation(id="3")
    n4 = PolymorphicStringTestInvocation(id="4")
    g.add_node(n1)
    g.add_node(n2)
    g.add_node(n3)
    g.add_node(n4)
    e1 = create_edge(n1.id, "value", n3.id, "item")
    e2 = create_edge(n2.id, "value", n3.id, "item")
    e3 = create_edge(n3.id, "collection", n4.id, "value")
    g.add_edge(e1)
    g.add_edge(e2)
    with pytest.raises(InvalidEdgeError):
        g.add_edge(e3)


def test_any_accepts_integer():
    g = Graph()
    n1 = IntegerInvocation(id="1", value=1)
    n2 = AnyTypeTestInvocation(id="2")
    g.add_node(n1)
    g.add_node(n2)
    e = create_edge(n1.id, "value", n2.id, "value")
    # Not throwing on this line is sufficient
    g.add_edge(e)


def test_any_accepts_string():
    g = Graph()
    n1 = StringInvocation(id="1", value="banana sundae")
    n2 = AnyTypeTestInvocation(id="2")
    g.add_node(n1)
    g.add_node(n2)
    e = create_edge(n1.id, "value", n2.id, "value")
    # Not throwing on this line is sufficient
    g.add_edge(e)


def test_any_accepts_generic_collection():
    g = Graph()
    n1 = IntegerInvocation(id="1", value=1)
    n2 = IntegerInvocation(id="2", value=2)
    n3 = CollectInvocation(id="3")
    n4 = AnyTypeTestInvocation(id="4")
    g.add_node(n1)
    g.add_node(n2)
    g.add_node(n3)
    g.add_node(n4)
    e1 = create_edge(n1.id, "value", n3.id, "item")
    e2 = create_edge(n2.id, "value", n3.id, "item")
    e3 = create_edge(n3.id, "collection", n4.id, "value")
    g.add_edge(e1)
    g.add_edge(e2)
    # Not throwing on this line is sufficient
    g.add_edge(e3)


def test_any_accepts_prompt_collection():
    g = Graph()
    n1 = PromptCollectionTestInvocation(id="1", collection=["banana", "sundae"])
    n2 = AnyTypeTestInvocation(id="2")
    g.add_node(n1)
    g.add_node(n2)
    e = create_edge(n1.id, "collection", n2.id, "value")
    # Not throwing on this line is sufficient
    g.add_edge(e)


def test_any_accepts_any():
    g = Graph()
    n1 = AnyTypeTestInvocation(id="1")
    n2 = AnyTypeTestInvocation(id="2")
    g.add_node(n1)
    g.add_node(n2)
    e = create_edge(n1.id, "value", n2.id, "value")
    # Not throwing on this line is sufficient
    g.add_edge(e)


@pytest.mark.xfail(
    reason="""We need to update the validation for Collect -> Iterate to traverse to the Iterate
    node's output and compare that against the item type of the Collect node's collection. Until
    then, Collect nodes may not output into Iterate nodes."""
)
def test_iterate_accepts_collection():
    g = Graph()
    n1 = IntegerInvocation(id="1", value=1)
    n2 = IntegerInvocation(id="2", value=2)
    n3 = CollectInvocation(id="3")
    n4 = IterateInvocation(id="4")
    g.add_node(n1)
    g.add_node(n2)
    g.add_node(n3)
    g.add_node(n4)
    e1 = create_edge(n1.id, "value", n3.id, "item")
    e2 = create_edge(n2.id, "value", n3.id, "item")
    e3 = create_edge(n3.id, "collection", n4.id, "collection")
    g.add_edge(e1)
    g.add_edge(e2)
    # eventually this should succeed
    with pytest.raises(InvalidEdgeError, match="Cannot connect collector to iterator"):
        g.add_edge(e3)


def test_graph_can_generate_schema():
    # Not throwing on this line is sufficient
    # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation

@@ -81,29 +81,6 @@ class PromptCollectionTestInvocation(BaseInvocation):
        return PromptCollectionTestInvocationOutput(collection=self.collection.copy())


@invocation_output("test_any_output")
class AnyTypeTestInvocationOutput(BaseInvocationOutput):
    value: Any = Field()


@invocation("test_any")
class AnyTypeTestInvocation(BaseInvocation):
    value: Any = Field(default=None)

    def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
        return AnyTypeTestInvocationOutput(value=self.value)


@invocation("test_polymorphic")
class PolymorphicStringTestInvocation(BaseInvocation):
    value: Union[str, list[str]] = Field(default="")

    def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
        if isinstance(self.value, str):
            return PromptCollectionTestInvocationOutput(collection=[self.value])
        return PromptCollectionTestInvocationOutput(collection=self.value)


# Importing these must happen after test invocations are defined or they won't register
from invokeai.app.services.events import EventServiceBase  # noqa: E402
from invokeai.app.services.graph import Edge, EdgeConnection  # noqa: E402
