Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-15 09:18:00 -05:00)

Compare commits (60 commits): feat/api/i...feat/contr
| SHA1 |
|---|
| e4a45341c8 |
| 4ca325e8e6 |
| 6b8e88ad7f |
| 0497bea264 |
| b8e32fa459 |
| 34ebee67b7 |
| e0c998d192 |
| b51e9a6bdb |
| 09f396ce84 |
| abee37eab3 |
| 42e48b2bef |
| 70ece4364c |
| f9d5f9d52c |
| 587297878a |
| b4c998a9ae |
| 88e8e3977b |
| 24b86cffe9 |
| a1773197e9 |
| 6c53abc034 |
| eb7047b21d |
| 43419ac761 |
| 5cd0e90816 |
| cfd49e3921 |
| a8e0490133 |
| 1e08d865c9 |
| f8bb650cc1 |
| 2cee8bebb2 |
| ade4ec5fd8 |
| 70ffd6b03f |
| 6c551df311 |
| 24f605629e |
| 2af1ec9d02 |
| 79d53341de |
| e40b3506c4 |
| 33912382e3 |
| d282810e53 |
| 9df502fc77 |
| 705573f0a8 |
| 1878ea94f6 |
| 4ba5086b9a |
| 4a991b4daa |
| 80474d26f9 |
| 9a77bd9140 |
| 14cdc800c3 |
| 9cfbea4c25 |
| 5fe674e223 |
| 32200efce8 |
| 68a02da990 |
| 5b20766ea3 |
| 9a914250a0 |
| 0e3106f631 |
| de3e6cdb02 |
| 8495764d45 |
| 8b7fac75ed |
| 9e0e26f4c4 |
| 46cac6468e |
| 2a814d886b |
| fd715026a7 |
| 7bce455d16 |
| 68405910ba |
@@ -70,27 +70,25 @@ async def upload_image(
         raise HTTPException(status_code=500, detail="Failed to create image")


-@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
+@images_router.delete("/{image_name}", operation_id="delete_image")
 async def delete_image(
-    image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
     image_name: str = Path(description="The name of the image to delete"),
 ) -> None:
     """Deletes an image"""

     try:
-        ApiDependencies.invoker.services.images.delete(image_origin, image_name)
+        ApiDependencies.invoker.services.images.delete(image_name)
     except Exception as e:
         # TODO: Does this need any exception handling at all?
         pass


 @images_router.patch(
-    "/{image_origin}/{image_name}",
+    "/{image_name}",
     operation_id="update_image",
     response_model=ImageDTO,
 )
 async def update_image(
-    image_origin: ResourceOrigin = Path(description="The origin of image to update"),
     image_name: str = Path(description="The name of the image to update"),
     image_changes: ImageRecordChanges = Body(
         description="The changes to apply to the image"
@@ -99,32 +97,29 @@ async def update_image(
     """Updates an image"""

     try:
-        return ApiDependencies.invoker.services.images.update(
-            image_origin, image_name, image_changes
-        )
+        return ApiDependencies.invoker.services.images.update(image_name, image_changes)
     except Exception as e:
         raise HTTPException(status_code=400, detail="Failed to update image")


 @images_router.get(
-    "/{image_origin}/{image_name}/metadata",
+    "/{image_name}/metadata",
     operation_id="get_image_metadata",
     response_model=ImageDTO,
 )
 async def get_image_metadata(
-    image_origin: ResourceOrigin = Path(description="The origin of image to get"),
     image_name: str = Path(description="The name of image to get"),
 ) -> ImageDTO:
     """Gets an image's metadata"""

     try:
-        return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
+        return ApiDependencies.invoker.services.images.get_dto(image_name)
     except Exception as e:
         raise HTTPException(status_code=404)


 @images_router.get(
-    "/{image_origin}/{image_name}",
+    "/{image_name}",
     operation_id="get_image_full",
     response_class=Response,
     responses={
@@ -136,15 +131,12 @@ async def get_image_metadata(
     },
 )
 async def get_image_full(
-    image_origin: ResourceOrigin = Path(
-        description="The type of full-resolution image file to get"
-    ),
     image_name: str = Path(description="The name of full-resolution image file to get"),
 ) -> FileResponse:
     """Gets a full-resolution image file"""

     try:
-        path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
+        path = ApiDependencies.invoker.services.images.get_path(image_name)

         if not ApiDependencies.invoker.services.images.validate_path(path):
             raise HTTPException(status_code=404)
@@ -160,7 +152,7 @@ async def get_image_full(


 @images_router.get(
-    "/{image_origin}/{image_name}/thumbnail",
+    "/{image_name}/thumbnail",
     operation_id="get_image_thumbnail",
     response_class=Response,
     responses={
@@ -172,14 +164,13 @@ async def get_image_full(
     },
 )
 async def get_image_thumbnail(
-    image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
     image_name: str = Path(description="The name of thumbnail image file to get"),
 ) -> FileResponse:
     """Gets a thumbnail image file"""

     try:
         path = ApiDependencies.invoker.services.images.get_path(
-            image_origin, image_name, thumbnail=True
+            image_name, thumbnail=True
         )
         if not ApiDependencies.invoker.services.images.validate_path(path):
             raise HTTPException(status_code=404)
@@ -192,25 +183,21 @@ async def get_image_thumbnail(


 @images_router.get(
-    "/{image_origin}/{image_name}/urls",
+    "/{image_name}/urls",
     operation_id="get_image_urls",
     response_model=ImageUrlsDTO,
 )
 async def get_image_urls(
-    image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
     image_name: str = Path(description="The name of the image whose URL to get"),
 ) -> ImageUrlsDTO:
     """Gets an image and thumbnail URL"""

     try:
-        image_url = ApiDependencies.invoker.services.images.get_url(
-            image_origin, image_name
-        )
+        image_url = ApiDependencies.invoker.services.images.get_url(image_name)
         thumbnail_url = ApiDependencies.invoker.services.images.get_url(
-            image_origin, image_name, thumbnail=True
+            image_name, thumbnail=True
         )
         return ImageUrlsDTO(
-            image_origin=image_origin,
             image_name=image_name,
             image_url=image_url,
             thumbnail_url=thumbnail_url,
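The hunks above drop the image origin from every image route, so clients address images by name alone. A minimal client-side sketch of the resulting URL shape (the base URL, port, and image name below are illustrative assumptions, not values from the diff):

```python
import requests

BASE = "http://localhost:9090/api/v1"  # assumed local server and API prefix
image_name = "abcd1234.png"            # hypothetical image name

# Old route shape (removed): /images/{image_origin}/{image_name}/metadata
# New route shape (added):   /images/{image_name}/metadata
response = requests.get(f"{BASE}/images/{image_name}/metadata")
print(response.status_code, response.json())  # ImageDTO fields on success
```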
@@ -1,7 +1,7 @@
 # InvokeAI nodes for ControlNet image preprocessors
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
-from builtins import float
+from builtins import float, bool

 import numpy as np
 from typing import Literal, Optional, Union, List
@@ -94,6 +94,7 @@ CONTROLNET_DEFAULT_MODELS = [
 ]

 CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
+CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]

 class ControlField(BaseModel):
     image: ImageField = Field(default=None, description="The control image")
@@ -104,6 +105,8 @@ class ControlField(BaseModel):
                                       description="When the ControlNet is first applied (% of total steps)")
     end_step_percent: float = Field(default=1, ge=0, le=1,
                                     description="When the ControlNet is last applied (% of total steps)")
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The contorl mode to use")
+
     @validator("control_weight")
     def abs_le_one(cls, v):
         """validate that all abs(values) are <=1"""
@@ -144,11 +147,11 @@ class ControlNetInvocation(BaseInvocation):
     control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
                                                   description="control model used")
     control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
     # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
     begin_step_percent: float = Field(default=0, ge=0, le=1,
                                       description="When the ControlNet is first applied (% of total steps)")
     end_step_percent: float = Field(default=1, ge=0, le=1,
                                     description="When the ControlNet is last applied (% of total steps)")
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
     # fmt: on

     class Config(InvocationConfig):
@@ -166,7 +169,6 @@ class ControlNetInvocation(BaseInvocation):
         }

     def invoke(self, context: InvocationContext) -> ControlOutput:
-
         return ControlOutput(
             control=ControlField(
                 image=self.image,
@@ -174,6 +176,7 @@ class ControlNetInvocation(BaseInvocation):
                 control_weight=self.control_weight,
                 begin_step_percent=self.begin_step_percent,
                 end_step_percent=self.end_step_percent,
+                control_mode=self.control_mode,
             ),
         )

@@ -193,9 +196,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
         return image

     def invoke(self, context: InvocationContext) -> ImageOutput:
-        raw_image = context.services.images.get_pil_image(
-            self.image.image_origin, self.image.image_name
-        )
+        raw_image = context.services.images.get_pil_image(self.image.image_name)
         # image type should be PIL.PngImagePlugin.PngImageFile ?
         processed_image = self.run_processor(raw_image)

@@ -216,10 +217,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
         )

         """Builds an ImageOutput and its ImageField"""
-        processed_image_field = ImageField(
-            image_name=image_dto.image_name,
-            image_origin=image_dto.image_origin,
-        )
+        processed_image_field = ImageField(image_name=image_dto.image_name)
         return ImageOutput(
             image=processed_image_field,
             # width=processed_image.width,
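As a side note, here is a standalone sketch of how a Literal-constrained control_mode and the control_weight validator fit together in a pydantic model (class name and field subset are simplified stand-ins for illustration, not the project's full ControlField):

```python
from typing import List, Literal, Union

from pydantic import BaseModel, Field, validator

CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]


class ControlFieldSketch(BaseModel):
    # A scalar weight or a per-step list of weights for the ControlNet.
    control_weight: Union[float, List[float]] = Field(default=1.0)
    # Pydantic rejects any value outside the Literal at validation time.
    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced")

    @validator("control_weight")
    def abs_le_one(cls, v):
        """Validate that all abs(values) are <= 1, as the diff's docstring states."""
        weights = v if isinstance(v, list) else [v]
        if any(abs(w) > 1 for w in weights):
            raise ValueError("all control weights must satisfy abs(value) <= 1")
        return v


# An out-of-range mode or weight raises a pydantic ValidationError.
print(ControlFieldSketch(control_weight=[0.5, -0.25], control_mode="more_control"))
```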
@@ -36,12 +36,8 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
mask = context.services.images.get_pil_image(
|
||||
self.mask.image_origin, self.mask.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
mask = context.services.images.get_pil_image(self.mask.image_name)
|
||||
|
||||
# Convert to cv image/mask
|
||||
# TODO: consider making these utility functions
|
||||
@@ -65,10 +61,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -86,9 +86,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
# loading controlnet image (currently requires pre-processed image)
|
||||
control_image = (
|
||||
None if self.control_image is None
|
||||
else context.services.images.get_pil_image(
|
||||
self.control_image.image_origin, self.control_image.image_name
|
||||
)
|
||||
else context.services.images.get_pil_image(self.control_image.image_name)
|
||||
)
|
||||
# loading controlnet model
|
||||
if (self.control_model is None or self.control_model==''):
|
||||
@@ -128,10 +126,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -169,9 +164,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
else context.services.images.get_pil_image(self.image.image_name)
|
||||
)
|
||||
|
||||
if self.fit:
|
||||
@@ -209,10 +202,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -282,14 +272,12 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
else context.services.images.get_pil_image(self.image.image_name)
|
||||
)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
|
||||
else context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
|
||||
# Handle invalid model parameter
|
||||
@@ -325,10 +313,7 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -72,13 +72,10 @@ class LoadImageInvocation(BaseInvocation):
|
||||
)
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=self.image.image_name,
|
||||
image_origin=self.image.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
@@ -95,19 +92,14 @@ class ShowImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
if image:
|
||||
image.show()
|
||||
|
||||
# TODO: how to handle failure?
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=self.image.image_name,
|
||||
image_origin=self.image.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
@@ -128,9 +120,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_crop = Image.new(
|
||||
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
|
||||
@@ -147,10 +137,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -171,19 +158,13 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get_pil_image(
|
||||
self.base_image.image_origin, self.base_image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else ImageOps.invert(
|
||||
context.services.images.get_pil_image(
|
||||
self.mask.image_origin, self.mask.image_name
|
||||
)
|
||||
context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
)
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
@@ -209,10 +190,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -230,9 +208,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_mask = image.split()[-1]
|
||||
if self.invert:
|
||||
@@ -248,9 +224,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return MaskOutput(
|
||||
mask=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
mask=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -268,12 +242,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image1 = context.services.images.get_pil_image(
|
||||
self.image1.image_origin, self.image1.image_name
|
||||
)
|
||||
image2 = context.services.images.get_pil_image(
|
||||
self.image2.image_origin, self.image2.image_name
|
||||
)
|
||||
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
||||
image2 = context.services.images.get_pil_image(self.image2.image_name)
|
||||
|
||||
multiply_image = ImageChops.multiply(image1, image2)
|
||||
|
||||
@@ -287,9 +257,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -310,9 +278,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
channel_image = image.getchannel(self.channel)
|
||||
|
||||
@@ -326,9 +292,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -349,9 +313,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
converted_image = image.convert(self.mode)
|
||||
|
||||
@@ -365,9 +327,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -386,9 +346,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
blur = (
|
||||
ImageFilter.GaussianBlur(self.radius)
|
||||
@@ -407,10 +365,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -450,9 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
@@ -471,10 +424,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -493,9 +443,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
width = int(image.width * self.scale_factor)
|
||||
@@ -516,10 +464,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -538,9 +483,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||
image_arr = image_arr * (self.max - self.min) + self.max
|
||||
@@ -557,10 +500,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -579,9 +519,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||
image_arr = (
|
||||
@@ -603,10 +541,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -134,9 +134,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||
@@ -153,10 +151,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -179,9 +174,7 @@ class InfillTileInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
infilled = tile_fill_missing(
|
||||
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||
@@ -198,10 +191,7 @@ class InfillTileInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -217,9 +207,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(image.copy())
|
||||
@@ -236,10 +224,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ import einops
 from typing import Literal, Optional, Union, List

 from compel import Compel
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
+from diffusers.pipelines.controlnet import MultiControlNetModel

 from pydantic import BaseModel, Field, validator
 import torch
@@ -282,19 +282,14 @@ class TextToLatentsInvocation(BaseInvocation):
         control_height_resize = latents_shape[2] * 8
         control_width_resize = latents_shape[3] * 8
         if control_input is None:
-            # print("control input is None")
             control_list = None
         elif isinstance(control_input, list) and len(control_input) == 0:
-            # print("control input is empty list")
             control_list = None
         elif isinstance(control_input, ControlField):
-            # print("control input is ControlField")
             control_list = [control_input]
         elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
-            # print("control input is list[ControlField]")
             control_list = control_input
         else:
-            # print("input control is unrecognized:", type(self.control))
             control_list = None
         if (control_list is None):
             control_data = None
@@ -321,8 +316,7 @@ class TextToLatentsInvocation(BaseInvocation):
                                                     torch_dtype=model.unet.dtype).to(model.device)
                 control_models.append(control_model)
                 control_image_field = control_info.image
-                input_image = context.services.images.get_pil_image(control_image_field.image_origin,
-                                                                    control_image_field.image_name)
+                input_image = context.services.images.get_pil_image(control_image_field.image_name)
                 # self.image.image_type, self.image.image_name
                 # FIXME: still need to test with different widths, heights, devices, dtypes
                 #        and add in batch_size, num_images_per_prompt?
@@ -337,12 +331,15 @@ class TextToLatentsInvocation(BaseInvocation):
                     # num_images_per_prompt=num_images_per_prompt,
                     device=control_model.device,
                     dtype=control_model.dtype,
+                    control_mode=control_info.control_mode,
                 )
                 control_item = ControlNetData(model=control_model,
                                               image_tensor=control_image,
                                               weight=control_info.control_weight,
                                               begin_step_percent=control_info.begin_step_percent,
-                                              end_step_percent=control_info.end_step_percent)
+                                              end_step_percent=control_info.end_step_percent,
+                                              control_mode=control_info.control_mode,
+                                              )
                 control_data.append(control_item)
             # MultiControlNetModel has been refactored out, just need list[ControlNetData]
             return control_data
@@ -502,10 +499,7 @@ class LatentsToImageInvocation(BaseInvocation):
         )

         return ImageOutput(
-            image=ImageField(
-                image_name=image_dto.image_name,
-                image_origin=image_dto.image_origin,
-            ),
+            image=ImageField(image_name=image_dto.image_name),
             width=image_dto.width,
             height=image_dto.height,
         )
@@ -601,9 +595,7 @@ class ImageToLatentsInvocation(BaseInvocation):
         # image = context.services.images.get(
         #     self.image.image_type, self.image.image_name
         # )
-        image = context.services.images.get_pil_image(
-            self.image.image_origin, self.image.image_name
-        )
+        image = context.services.images.get_pil_image(self.image.image_name)

         # TODO: this only really needs the vae
         model_info = choose_model(context.services.model_manager, self.model)
@@ -2,8 +2,8 @@ from typing import Literal

 from pydantic.fields import Field

-from .baseinvocation import BaseInvocationOutput
-
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
+from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator

 class PromptOutput(BaseInvocationOutput):
     """Base class for invocations that output a prompt"""
@@ -20,3 +20,38 @@ class PromptOutput(BaseInvocationOutput):
                 'prompt',
             ]
         }
+
+
+class PromptCollectionOutput(BaseInvocationOutput):
+    """Base class for invocations that output a collection of prompts"""
+
+    # fmt: off
+    type: Literal["prompt_collection_output"] = "prompt_collection_output"
+
+    prompt_collection: list[str] = Field(description="The output prompt collection")
+    count: int = Field(description="The size of the prompt collection")
+    # fmt: on
+
+    class Config:
+        schema_extra = {"required": ["type", "prompt_collection", "count"]}
+
+
+class DynamicPromptInvocation(BaseInvocation):
+    """Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
+
+    type: Literal["dynamic_prompt"] = "dynamic_prompt"
+    prompt: str = Field(description="The prompt to parse with dynamicprompts")
+    max_prompts: int = Field(default=1, description="The number of prompts to generate")
+    combinatorial: bool = Field(
+        default=False, description="Whether to use the combinatorial generator"
+    )
+
+    def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
+        if self.combinatorial:
+            generator = CombinatorialPromptGenerator()
+            prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
+        else:
+            generator = RandomPromptGenerator()
+            prompts = generator.generate(self.prompt, num_images=self.max_prompts)
+
+        return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
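A small standalone sketch of the two dynamicprompts generators the new invocation wraps; the template string is made up, and the generator call signatures follow the usage shown in the diff above:

```python
from dynamicprompts.generators import (
    CombinatorialPromptGenerator,
    RandomPromptGenerator,
)

template = "a {red|green|blue} house, {oil painting|photo}"  # hypothetical prompt

# Random sampling: draw up to `num_images` prompts from the template's options.
random_prompts = RandomPromptGenerator().generate(template, num_images=3)

# Combinatorial expansion: every combination, capped at `max_prompts`.
all_prompts = CombinatorialPromptGenerator().generate(template, max_prompts=6)

print(list(random_prompts))
print(list(all_prompts))  # six combinations for this template
```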
@@ -28,9 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=None,
|
||||
@@ -51,10 +49,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -30,9 +30,7 @@ class UpscaleInvocation(BaseInvocation):
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=(self.level, self.strength),
|
||||
@@ -53,10 +51,7 @@ class UpscaleInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -66,13 +66,10 @@ class InvalidImageCategoryException(ValueError):
 class ImageField(BaseModel):
     """An image field used for passing image objects between invocations"""

-    image_origin: ResourceOrigin = Field(
-        default=ResourceOrigin.INTERNAL, description="The type of the image"
-    )
     image_name: Optional[str] = Field(default=None, description="The name of the image")

     class Config:
-        schema_extra = {"required": ["image_origin", "image_name"]}
+        schema_extra = {"required": ["image_name"]}


 class ColorField(BaseModel):
@@ -1,5 +1,4 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
-import os
 from abc import ABC, abstractmethod
 from pathlib import Path
 from queue import Queue
@@ -40,14 +39,12 @@ class ImageFileStorageBase(ABC):
     """Low-level service responsible for storing and retrieving image files."""

     @abstractmethod
-    def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
+    def get(self, image_name: str) -> PILImageType:
         """Retrieves an image as PIL Image."""
         pass

     @abstractmethod
-    def get_path(
-        self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
-    ) -> str:
+    def get_path(self, image_name: str, thumbnail: bool = False) -> str:
         """Gets the internal path to an image or thumbnail."""
         pass

@@ -62,7 +59,6 @@ class ImageFileStorageBase(ABC):
     def save(
         self,
         image: PILImageType,
-        image_origin: ResourceOrigin,
         image_name: str,
         metadata: Optional[ImageMetadata] = None,
         thumbnail_size: int = 256,
@@ -71,7 +67,7 @@ class ImageFileStorageBase(ABC):
         pass

     @abstractmethod
-    def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
+    def delete(self, image_name: str) -> None:
         """Deletes an image and its thumbnail (if one exists)."""
         pass

@@ -79,31 +75,26 @@ class ImageFileStorageBase(ABC):
 class DiskImageFileStorage(ImageFileStorageBase):
     """Stores images on disk"""

-    __output_folder: str
+    __output_folder: Path
     __cache_ids: Queue  # TODO: this is an incredibly naive cache
-    __cache: Dict[str, PILImageType]
+    __cache: Dict[Path, PILImageType]
     __max_cache_size: int

-    def __init__(self, output_folder: str):
-        self.__output_folder = output_folder
+    def __init__(self, output_folder: str | Path):
         self.__cache = dict()
         self.__cache_ids = Queue()
         self.__max_cache_size = 10  # TODO: get this from config

-        Path(output_folder).mkdir(parents=True, exist_ok=True)
+        self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
+        self.__thumbnails_folder = self.__output_folder / 'thumbnails'

-        # TODO: don't hard-code. get/save/delete should maybe take subpath?
-        for image_origin in ResourceOrigin:
-            Path(os.path.join(output_folder, image_origin)).mkdir(
-                parents=True, exist_ok=True
-            )
-            Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
-                parents=True, exist_ok=True
-            )
+        # Validate required output folders at launch
+        self.__validate_storage_folders()

-    def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
+    def get(self, image_name: str) -> PILImageType:
         try:
-            image_path = self.get_path(image_origin, image_name)
+            image_path = self.get_path(image_name)

             cache_item = self.__get_cache(image_path)
             if cache_item:
                 return cache_item
@@ -117,13 +108,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
     def save(
         self,
         image: PILImageType,
-        image_origin: ResourceOrigin,
         image_name: str,
         metadata: Optional[ImageMetadata] = None,
         thumbnail_size: int = 256,
     ) -> None:
         try:
-            image_path = self.get_path(image_origin, image_name)
+            self.__validate_storage_folders()
+            image_path = self.get_path(image_name)

             if metadata is not None:
                 pnginfo = PngImagePlugin.PngInfo()
@@ -133,7 +124,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
             image.save(image_path, "PNG")

             thumbnail_name = get_thumbnail_name(image_name)
-            thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
+            thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
             thumbnail_image = make_thumbnail(image, thumbnail_size)
             thumbnail_image.save(thumbnail_path)

@@ -142,20 +133,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
         except Exception as e:
             raise ImageFileSaveException from e

-    def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
+    def delete(self, image_name: str) -> None:
         try:
-            basename = os.path.basename(image_name)
-            image_path = self.get_path(image_origin, basename)
+            image_path = self.get_path(image_name)

-            if os.path.exists(image_path):
+            if image_path.exists():
                 send2trash(image_path)
             if image_path in self.__cache:
                 del self.__cache[image_path]

             thumbnail_name = get_thumbnail_name(image_name)
-            thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
+            thumbnail_path = self.get_path(thumbnail_name, True)

-            if os.path.exists(thumbnail_path):
+            if thumbnail_path.exists():
                 send2trash(thumbnail_path)
             if thumbnail_path in self.__cache:
                 del self.__cache[thumbnail_path]
@@ -163,41 +153,33 @@ class DiskImageFileStorage(ImageFileStorageBase):
             raise ImageFileDeleteException from e

     # TODO: make this a bit more flexible for e.g. cloud storage
-    def get_path(
-        self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
-    ) -> str:
-        # strip out any relative path shenanigans
-        basename = os.path.basename(image_name)
+    def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
+        path = self.__output_folder / image_name

         if thumbnail:
-            thumbnail_name = get_thumbnail_name(basename)
-            path = os.path.join(
-                self.__output_folder, image_origin, "thumbnails", thumbnail_name
-            )
-        else:
-            path = os.path.join(self.__output_folder, image_origin, basename)
+            thumbnail_name = get_thumbnail_name(image_name)
+            path = self.__thumbnails_folder / thumbnail_name

-        abspath = os.path.abspath(path)
+        return path

-        return abspath

-    def validate_path(self, path: str) -> bool:
+    def validate_path(self, path: str | Path) -> bool:
         """Validates the path given for an image or thumbnail."""
-        try:
-            os.stat(path)
-            return True
-        except:
-            return False
+        path = path if isinstance(path, Path) else Path(path)
+        return path.exists()

+    def __validate_storage_folders(self) -> None:
+        """Checks if the required output folders exist and create them if they don't"""
+        folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
+        for folder in folders:
+            folder.mkdir(parents=True, exist_ok=True)

-    def __get_cache(self, image_name: str) -> PILImageType | None:
+    def __get_cache(self, image_name: Path) -> PILImageType | None:
         return None if image_name not in self.__cache else self.__cache[image_name]

-    def __set_cache(self, image_name: str, image: PILImageType):
+    def __set_cache(self, image_name: Path, image: PILImageType):
         if not image_name in self.__cache:
             self.__cache[image_name] = image
-            self.__cache_ids.put(
-                image_name
-            )  # TODO: this should refresh position for LRU cache
+            self.__cache_ids.put(image_name)  # TODO: this should refresh position for LRU cache
             if len(self.__cache) > self.__max_cache_size:
                 cache_id = self.__cache_ids.get()
                 if cache_id in self.__cache:
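A minimal standalone sketch of the flat on-disk layout this refactor moves to, with full-size images in one folder and thumbnails in a single thumbnails/ subfolder (the folder and file names are placeholders; the helper mirrors the diff's get_path/validate_path behavior without per-origin subfolders):

```python
from pathlib import Path

output_folder = Path("outputs/images")           # assumed output location
thumbnails_folder = output_folder / "thumbnails"

# Ensure both folders exist, as __validate_storage_folders() does in the diff.
for folder in (output_folder, thumbnails_folder):
    folder.mkdir(parents=True, exist_ok=True)


def get_path(image_name: str, thumbnail: bool = False) -> Path:
    # No per-origin subfolders any more: a name resolves directly to a path.
    return (thumbnails_folder / image_name) if thumbnail else (output_folder / image_name)


print(get_path("abcd1234.png"))                   # outputs/images/abcd1234.png
print(get_path("abcd1234.webp", thumbnail=True))  # outputs/images/thumbnails/abcd1234.webp
print(get_path("abcd1234.png").exists())          # validate_path() reduces to an existence check
```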
@@ -21,6 +21,7 @@ from invokeai.app.services.models.image_record import (
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
||||
"""Offset-paginated results"""
|
||||
|
||||
@@ -60,7 +61,7 @@ class ImageRecordStorageBase(ABC):
|
||||
# TODO: Implement an `update()` method
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@@ -68,7 +69,6 @@ class ImageRecordStorageBase(ABC):
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
"""Updates an image record."""
|
||||
@@ -89,7 +89,7 @@ class ImageRecordStorageBase(ABC):
|
||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||
@abstractmethod
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
def delete(self, image_name: str) -> None:
|
||||
"""Deletes an image record."""
|
||||
pass
|
||||
|
||||
@@ -196,9 +196,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
|
||||
def get(
|
||||
self, image_origin: ResourceOrigin, image_name: str
|
||||
) -> Union[ImageRecord, None]:
|
||||
def get(self, image_name: str) -> Union[ImageRecord, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
@@ -225,7 +223,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
try:
|
||||
@@ -294,9 +291,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
if categories is not None:
|
||||
## Convert the enum values to unique list of strings
|
||||
category_strings = list(
|
||||
map(lambda c: c.value, set(categories))
|
||||
)
|
||||
category_strings = list(map(lambda c: c.value, set(categories)))
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||
@@ -337,7 +332,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
def delete(self, image_name: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
|
||||
@@ -57,7 +57,6 @@ class ImageServiceABC(ABC):
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
@@ -65,22 +64,22 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
"""Gets an image as a PIL image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
|
||||
def get_path(self, image_name: str) -> str:
|
||||
"""Gets an image's path."""
|
||||
pass
|
||||
|
||||
@@ -90,9 +89,7 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's or thumbnail's URL."""
|
||||
pass
|
||||
|
||||
@@ -109,7 +106,7 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||
def delete(self, image_name: str):
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
@@ -206,16 +203,13 @@ class ImageService(ImageServiceABC):
|
||||
)
|
||||
|
||||
self._services.files.save(
|
||||
image_origin=image_origin,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
image_url = self._services.urls.get_image_url(image_origin, image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(
|
||||
image_origin, image_name, True
|
||||
)
|
||||
image_url = self._services.urls.get_image_url(image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(image_name, True)
|
||||
|
||||
return ImageDTO(
|
||||
# Non-nullable fields
|
||||
@@ -249,13 +243,12 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.records.update(image_name, image_origin, changes)
|
||||
return self.get_dto(image_origin, image_name)
|
||||
self._services.records.update(image_name, changes)
|
||||
return self.get_dto(image_name)
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
raise
|
||||
@@ -263,9 +256,9 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem updating image record")
|
||||
raise e
|
||||
|
||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.files.get(image_origin, image_name)
|
||||
return self._services.files.get(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
@@ -273,9 +266,9 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image file")
|
||||
raise e
|
||||
|
||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.records.get(image_origin, image_name)
|
||||
return self._services.records.get(image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
@@ -283,14 +276,14 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image record")
|
||||
raise e
|
||||
|
||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.records.get(image_origin, image_name)
|
||||
image_record = self._services.records.get(image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self._services.urls.get_image_url(image_origin, image_name),
|
||||
self._services.urls.get_image_url(image_origin, image_name, True),
|
||||
self._services.urls.get_image_url(image_name),
|
||||
self._services.urls.get_image_url(image_name, True),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
@@ -301,11 +294,9 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_path(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_origin, image_name, thumbnail)
|
||||
return self._services.files.get_path(image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
@@ -317,11 +308,9 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem validating image path")
|
||||
raise e
|
||||
|
||||
def get_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
|
||||
return self._services.urls.get_image_url(image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
@@ -347,10 +336,8 @@ class ImageService(ImageServiceABC):
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_origin, r.image_name),
|
||||
self._services.urls.get_image_url(
|
||||
r.image_origin, r.image_name, True
|
||||
),
|
||||
self._services.urls.get_image_url(r.image_name),
|
||||
self._services.urls.get_image_url(r.image_name, True),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
@@ -366,10 +353,10 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting paginated image DTOs")
|
||||
raise e
|
||||
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||
def delete(self, image_name: str):
|
||||
try:
|
||||
self._services.files.delete(image_origin, image_name)
|
||||
self._services.records.delete(image_origin, image_name)
|
||||
self._services.files.delete(image_name)
|
||||
self._services.records.delete(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
raise
|
||||
|
||||
@@ -1,6 +1,5 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

-import os
 from abc import ABC, abstractmethod
 from pathlib import Path
 from queue import Queue
@@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
 class DiskLatentsStorage(LatentsStorageBase):
     """Stores latents in a folder on disk without caching"""

-    __output_folder: str
+    __output_folder: str | Path

-    def __init__(self, output_folder: str):
-        self.__output_folder = output_folder
-        Path(output_folder).mkdir(parents=True, exist_ok=True)
+    def __init__(self, output_folder: str | Path):
+        self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
+        self.__output_folder.mkdir(parents=True, exist_ok=True)

     def get(self, name: str) -> torch.Tensor:
         latent_path = self.get_path(name)
         return torch.load(latent_path)

     def save(self, name: str, data: torch.Tensor) -> None:
+        self.__output_folder.mkdir(parents=True, exist_ok=True)
         latent_path = self.get_path(name)
         torch.save(data, latent_path)

     def delete(self, name: str) -> None:
         latent_path = self.get_path(name)
-        os.remove(latent_path)
+        latent_path.unlink()

-    def get_path(self, name: str) -> str:
-        return os.path.join(self.__output_folder, name)
+    def get_path(self, name: str) -> Path:
+        return self.__output_folder / name
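For reference, a tiny self-contained sketch that mirrors the pathlib-based latents store above; the class is a simplified stand-in rather than the project's DiskLatentsStorage, and the folder name and tensor are placeholders:

```python
from pathlib import Path

import torch


class DiskLatentsSketch:
    """Simplified stand-in mirroring the diff's pathlib usage."""

    def __init__(self, output_folder: str | Path):
        self._folder = Path(output_folder)
        self._folder.mkdir(parents=True, exist_ok=True)

    def get_path(self, name: str) -> Path:
        return self._folder / name

    def save(self, name: str, data: torch.Tensor) -> None:
        # Re-create the folder in case it was removed at runtime, as the diff does.
        self._folder.mkdir(parents=True, exist_ok=True)
        torch.save(data, self.get_path(name))

    def get(self, name: str) -> torch.Tensor:
        return torch.load(self.get_path(name))

    def delete(self, name: str) -> None:
        self.get_path(name).unlink()


store = DiskLatentsSketch("outputs/latents")  # assumed folder
store.save("demo.pt", torch.randn(1, 4, 64, 64))
print(store.get("demo.pt").shape)             # torch.Size([1, 4, 64, 64])
store.delete("demo.pt")
```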
@@ -79,8 +79,6 @@ class ImageUrlsDTO(BaseModel):

     image_name: str = Field(description="The unique name of the image.")
     """The unique name of the image."""
-    image_origin: ResourceOrigin = Field(description="The type of the image.")
-    """The origin of the image."""
     image_url: str = Field(description="The URL of the image.")
     """The URL of the image."""
     thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
@@ -1,17 +1,12 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.models.image import ResourceOrigin
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name
|
||||
|
||||
|
||||
class UrlServiceBase(ABC):
|
||||
"""Responsible for building URLs for resources."""
|
||||
|
||||
@abstractmethod
|
||||
def get_image_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@@ -20,15 +15,11 @@ class LocalUrlService(UrlServiceBase):
|
||||
def __init__(self, base_url: str = "api/v1"):
|
||||
self._base_url = base_url
|
||||
|
||||
def get_image_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
image_basename = os.path.basename(image_name)
|
||||
|
||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||
if thumbnail:
|
||||
return (
|
||||
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
|
||||
)
|
||||
return f"{self._base_url}/images/{image_basename}/thumbnail"
|
||||
|
||||
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
|
||||
return f"{self._base_url}/images/{image_basename}"
|
||||
|
||||
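With the origin segment gone, the URL is built from the base URL and the image basename only. A quick check of the new shape; the base URL and file name here are example values:

    urls = LocalUrlService(base_url="api/v1")
    urls.get_image_url("cat.png")                  # -> "api/v1/images/cat.png"
    urls.get_image_url("cat.png", thumbnail=True)  # -> "api/v1/images/cat.png/thumbnail"
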
@@ -6,7 +6,7 @@ import torch

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel

from ..stable_diffusion import (
ConditioningData,

@@ -23,7 +23,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline,
@@ -46,6 +46,7 @@ from .diffusion import (
AttentionMapSaver,
InvokeAIDiffuserComponent,
PostprocessingSettings,
ControlNetData,
)
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
from .textual_inversion_manager import TextualInversionManager
@@ -214,13 +215,6 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
raise AssertionError("why was that an empty generator?")
return result

@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None)
weight: Union[float, List[float]]= Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)

@dataclass(frozen=True)
class ConditioningData:
@@ -656,69 +650,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

# TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep)

# default is no controlnet, so set controlnet processing output to None
down_block_res_samples, mid_block_res_sample = None, None

if control_data is not None:
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
# if conditioning_data.guidance_scale > 1.0:
if conditioning_data.guidance_scale is not None:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
latent_control_input = torch.cat([latent_model_input] * 2)
else:
latent_control_input = latent_model_input
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
# print("controlnet", i, "==>", type(control_datum))
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index]
else:
# if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight
down_samples, mid_sample = control_datum.model(
sample=latent_control_input,
timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]),
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight,
# cross_attention_kwargs,
guess_mode=False,
return_dict=False,
)
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
# add controlnet outputs together if have multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
unet_latent_input = self.scheduler.scale_model_input(latents, timestep)

# predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step(
latent_model_input,
t,
conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings,
conditioning_data.guidance_scale,
x=unet_latent_input,
sigma=t,
unconditioning=conditioning_data.unconditioned_embeddings,
conditioning=conditioning_data.text_embeddings,
unconditional_guidance_scale=conditioning_data.guidance_scale,
control_data=control_data,
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
)

# compute the previous noisy sample x_t -> x_t-1
@@ -1038,6 +981,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced"
):

if not isinstance(image, torch.Tensor):
@@ -1068,6 +1012,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
image = torch.cat([image] * 2)
#cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
#if do_classifier_free_guidance and not cfg_injection:
# image = torch.cat([image] * 2)
return image

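The repeat_interleave plus torch.cat above exists to make the control image batch match the doubled (unconditioned + conditioned) latent batch used for classifier-free guidance. The same bookkeeping in isolation, with made-up sizes:

    import torch

    image = torch.rand(1, 3, 512, 512)  # one prepared control image
    num_images_per_prompt = 2
    image = image.repeat_interleave(num_images_per_prompt, dim=0)  # (2, 3, 512, 512)
    do_classifier_free_guidance = True
    if do_classifier_free_guidance:
        image = torch.cat([image] * 2)  # uncond + cond halves: (4, 3, 512, 512)
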
@@ -3,4 +3,4 @@ Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin
from .cross_attention_map_saving import AttentionMapSaver
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings, ControlNetData

@@ -1,11 +1,14 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pydantic import Field
from math import ceil
from typing import Any, Callable, Dict, Optional, Union, List

import numpy as np
import torch
import math
from diffusers import UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias

@@ -40,6 +43,17 @@ class PostprocessingSettings:
v_symmetry_time_pct: Optional[float]


# TODO: pydantic Field work with dataclasses?
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor = Field(default=None)
weight: Union[float, List[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced")


class InvokeAIDiffuserComponent:
"""
The aim of this component is to provide a single place for code that can be applied identically to
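ControlNetData now carries a per-net control_mode next to the weight and the begin/end step window. A hedged construction sketch; the model handle and image tensor below are placeholders rather than real objects from the repo:

    total_step_count = 20
    control_datum = ControlNetData(
        model=controlnet_model,       # placeholder: a loaded diffusers ControlNetModel
        image_tensor=control_image,   # placeholder: preprocessed control image tensor
        weight=[0.8] * total_step_count,  # a per-step list or a single float
        begin_step_percent=0.0,
        end_step_percent=0.75,        # stop applying control after 75% of the steps
        control_mode="more_control",  # "balanced" | "more_prompt" | "more_control" | "unbalanced"
    )
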
@@ -182,8 +196,9 @@ class InvokeAIDiffuserComponent:
conditioning: Union[torch.Tensor, dict],
# unconditional_guidance_scale: float,
unconditional_guidance_scale: Union[float, List[float]],
step_index: Optional[int] = None,
total_step_count: Optional[int] = None,
step_index: int,
total_step_count: int,
control_data: Optional[List[ControlNetData]],
**kwargs,
):
"""
@@ -213,31 +228,56 @@ class InvokeAIDiffuserComponent:
)
)

if self.sequential_guidance:
down_block_res_samples, mid_block_res_sample = self._run_controlnet_sequentially(
unconditioning=unconditioning,
conditioning=conditioning,
control_data=control_data,
sample=x,
timestep=sigma,
step_index=step_index,
total_step_count=total_step_count,
)
else:
down_block_res_samples, mid_block_res_sample = self._run_controlnet_normally(
unconditioning=unconditioning,
conditioning=conditioning,
control_data=control_data,
sample=x,
timestep=sigma,
step_index=step_index,
total_step_count=total_step_count,
)

wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
wants_hybrid_conditioning = isinstance(conditioning, dict)

if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
x, sigma, unconditioning, conditioning, **kwargs,
x, sigma, unconditioning, conditioning,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
**kwargs,
)
elif wants_cross_attention_control:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
x,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
**kwargs,
)
elif self.sequential_guidance:
elif True: #self.sequential_guidance:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
x, sigma, unconditioning, conditioning, **kwargs,
x, sigma, unconditioning, conditioning,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
**kwargs,
)

else:
@@ -245,7 +285,10 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
x, sigma, unconditioning, conditioning, **kwargs,
x, sigma, unconditioning, conditioning,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
**kwargs,
)

combined_next_x = self._combine(
@@ -293,16 +336,160 @@ class InvokeAIDiffuserComponent:

# methods below are called from do_diffusion_step and should be considered private to this class.

def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
def _run_controlnet_normally(
self,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
control_data: List[ControlNetData],
sample: torch.Tensor,
timestep: torch.Tensor,
step_index: int,
total_step_count: int,
):
if control_data is None:
return (None, None)

down_block_res_samples, mid_block_res_sample = None, None

for i, control_datum in enumerate(control_data):
control_mode = control_datum.control_mode
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")

first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:

if cfg_injection:
control_sample = sample
control_timestep = timestep
control_image_tensor = control_datum.image_tensor
encoder_hidden_states = conditioning # TODO: ask bug
else:
control_sample = torch.cat([sample] * 2)
control_timestep = torch.cat([timestep] * 2)
control_image_tensor = torch.cat([control_datum.image_tensor] * 2)
encoder_hidden_states = torch.cat([unconditioning, conditioning])

if isinstance(control_datum.weight, list):
weight = control_datum.weight[step_index]
else:
weight = control_datum.weight

# controlnet(s) inference
down_samples, mid_sample = control_datum.model(
sample=control_sample,
timestep=control_timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_image_tensor,
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)

if cfg_injection:
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample

return down_block_res_samples, mid_block_res_sample

def _run_controlnet_sequentially(
self,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
control_data: List[ControlNetData],
sample: torch.Tensor,
timestep: torch.Tensor,
step_index: int,
total_step_count: int,
):
if control_data is None:
return (None, None)

down_block_res_samples, mid_block_res_sample = None, None

for i, control_datum in enumerate(control_data):
control_mode = control_datum.control_mode
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")

first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:

if isinstance(control_datum.weight, list):
weight = control_datum.weight[step_index]
else:
weight = control_datum.weight

# controlnet(s) inference
cond_down_samples, cond_mid_sample = control_datum.model(
sample=sample,
timestep=timestep,
encoder_hidden_states=conditioning, # TODO: ask bug
controlnet_cond=control_datum.image_tensor,
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)

if cfg_injection:
uncond_down_samples = [torch.zeros_like(d) for d in cond_down_samples]
uncond_mid_sample = torch.zeros_like(cond_mid_sample)

else:
uncond_down_samples, uncond_mid_sample = control_datum.model(
sample=sample,
timestep=timestep,
encoder_hidden_states=unconditioning,
controlnet_cond=control_datum.image_tensor,
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)

down_samples = [torch.cat([ud, cd]) for ud, cd in zip(uncond_down_samples, cond_down_samples)]
mid_sample = torch.cat([uncond_mid_sample, cond_mid_sample])

if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample

return down_block_res_samples, mid_block_res_sample

def _apply_standard_conditioning(
self,
x: torch.Tensor,
sigma: torch.Tensor,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
**kwargs
):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning])
both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, **kwargs,
)

both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings, **kwargs)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
if conditioned_next_x.device.type == "mps":
# TODO: check if this still present
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x
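The gating in both controlnet runners reduces to two small pieces of arithmetic: mapping the begin/end percentages onto step indices, and, in cfg_injection mode, zero-padding the conditional-only residuals for the unconditional half of the batch. A standalone sketch of just that logic, with invented function names:

    import math
    import torch

    def control_is_active(begin_pct: float, end_pct: float, step_index: int, total_steps: int) -> bool:
        first = math.floor(begin_pct * total_steps)
        last = math.ceil(end_pct * total_steps)
        return first <= step_index <= last

    def pad_residual_for_cfg(cond_residual: torch.Tensor) -> torch.Tensor:
        # zeros for the unconditional half, the real residual for the conditional half
        return torch.cat([torch.zeros_like(cond_residual), cond_residual])

    assert control_is_active(0.0, 0.75, step_index=10, total_steps=20)
    assert not control_is_active(0.0, 0.25, step_index=10, total_steps=20)
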
@@ -310,15 +497,43 @@ class InvokeAIDiffuserComponent:
def _apply_standard_conditioning_sequentially(
self,
x: torch.Tensor,
sigma,
sigma: torch.Tensor,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
down_block_additional_residuals, # from controlnet(s)
mid_block_additional_residual, # from controlnet(s)
**kwargs,
):
# split controlnet data to cond and uncond
if down_block_additional_residuals is None:
uncond_down_block_res_samples = None
cond_down_block_res_samples = None
uncond_mid_block_res_sample = None
cond_mid_block_res_sample = None

else:
uncond_down_block_res_samples = []
cond_down_block_res_samples = []
for d in down_block_additional_residuals:
ud, cd = d.chunk(2)
uncond_down_block_res_samples.append(ud)
cond_down_block_res_samples.append(cd)

uncond_mid_block_res_sample, cond_mid_block_res_sample = mid_block_additional_residual.chunk(2)

# low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
unconditioned_next_x = self.model_forward_callback(
x, sigma, unconditioning, **kwargs,
down_block_additional_residuals=uncond_down_block_res_samples,
mid_block_additional_residual=uncond_mid_block_res_sample,
)
conditioned_next_x = self.model_forward_callback(
x, sigma, conditioning, **kwargs,
down_block_additional_residuals=cond_down_block_res_samples,
mid_block_additional_residual=cond_mid_block_res_sample,
)
if conditioned_next_x.device.type == "mps":
# TODO: check if still present
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x

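Because the ControlNet residuals were produced for a doubled (uncond + cond) batch, the sequential path has to split each residual back into its two halves before the separate UNet passes. A minimal sketch of that split, with invented shapes:

    import torch

    # one residual computed for a doubled batch: first half uncond, second half cond
    residual = torch.rand(4, 320, 64, 64)
    uncond_half, cond_half = residual.chunk(2)   # each half is (2, 320, 64, 64)

    # the same split applied across a list of down-block residuals
    down_residuals = [torch.rand(4, 320, 64, 64), torch.rand(4, 640, 32, 32)]
    uncond_list = [r.chunk(2)[0] for r in down_residuals]
    cond_list = [r.chunk(2)[1] for r in down_residuals]
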
@@ -62,10 +62,12 @@
"@dagrejs/graphlib": "^2.1.12",
"@dnd-kit/core": "^6.0.8",
"@dnd-kit/modifiers": "^6.0.1",
"@emotion/react": "^11.10.6",
"@emotion/react": "^11.11.1",
"@emotion/styled": "^11.10.6",
"@floating-ui/react-dom": "^2.0.0",
"@fontsource/inter": "^4.5.15",
"@mantine/core": "^6.0.13",
"@mantine/hooks": "^6.0.13",
"@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5",
"chakra-ui-contextmenu": "^1.0.5",

@@ -524,7 +524,8 @@
"initialImage": "Initial Image",
"showOptionsPanel": "Show Options Panel",
"hidePreview": "Hide Preview",
"showPreview": "Show Preview"
"showPreview": "Show Preview",
"controlNetControlMode": "Control Mode"
},
"settings": {
"models": "Models",

@@ -3,11 +3,11 @@ import {
createLocalStorageManager,
extendTheme,
} from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { ReactNode, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { theme as invokeAITheme } from 'theme/theme';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';

import { greenTeaThemeColors } from 'theme/colors/greenTea';
import { invokeAIThemeColors } from 'theme/colors/invokeAI';
@@ -15,6 +15,8 @@ import { lightThemeColors } from 'theme/colors/lightTheme';
import { oceanBlueColors } from 'theme/colors/oceanBlue';

import '@fontsource/inter/variable.css';
import { MantineProvider } from '@mantine/core';
import { mantineTheme } from 'mantine-theme/theme';
import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css';

@@ -51,9 +53,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
}, [direction]);

return (
<ChakraProvider theme={theme} colorModeManager={manager}>
{children}
</ChakraProvider>
<MantineProvider withGlobalStyles theme={mantineTheme}>
<ChakraProvider theme={theme} colorModeManager={manager}>
{children}
</ChakraProvider>
</MantineProvider>
);
}


@@ -22,9 +22,9 @@ export const SCHEDULERS = [
export type Scheduler = (typeof SCHEDULERS)[number];

// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
{ key: '2x', value: 2 },
{ key: '4x', value: 4 },
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [
{ label: '2x', value: '2' },
{ label: '4x', value: '4' },
];
export const NUMPY_RAND_MIN = 0;


@@ -34,10 +34,7 @@ export const addControlNetImageProcessedListener = () => {
[controlNet.processorNode.id]: {
...controlNet.processorNode,
is_intermediate: true,
image: pick(controlNet.controlImage, [
'image_name',
'image_origin',
]),
image: pick(controlNet.controlImage, ['image_name']),
},
},
};

@@ -25,7 +25,7 @@ export const addRequestedImageDeletionListener = () => {
effect: (action, { dispatch, getState }) => {
const { image, imageUsage } = action.payload;

const { image_name, image_origin } = image;
const { image_name } = image;

const state = getState();
const selectedImage = state.gallery.selectedImage;
@@ -79,9 +79,7 @@ export const addRequestedImageDeletionListener = () => {
dispatch(imageRemoved(image_name));

// Delete from server
dispatch(
imageDeleted({ imageName: image_name, imageOrigin: image_origin })
);
dispatch(imageDeleted({ imageName: image_name }));
},
});
};

@@ -20,7 +20,6 @@ export const addImageMetadataReceivedFulfilledListener = () => {
dispatch(
imageUpdated({
imageName: image.image_name,
imageOrigin: image.image_origin,
requestBody: { is_intermediate: false },
})
);

@@ -36,13 +36,12 @@ export const addInvocationCompleteEventListener = () => {

// This complete event has an associated image output
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const { image_name, image_origin } = result.image;
const { image_name } = result.image;

// Get its metadata
dispatch(
imageMetadataReceived({
imageName: image_name,
imageOrigin: image_origin,
})
);


@@ -11,12 +11,11 @@ export const addStagingAreaImageSavedListener = () => {
startAppListening({
actionCreator: stagingAreaImageSaved,
effect: async (action, { dispatch, getState, take }) => {
const { image_name, image_origin } = action.payload;
const { image_name } = action.payload;

dispatch(
imageUpdated({
imageName: image_name,
imageOrigin: image_origin,
requestBody: {
is_intermediate: false,
},

@@ -80,11 +80,10 @@ export const addUpdateImageUrlsOnConnectListener = () => {
`Fetching new image URLs for ${allUsedImages.length} images`
);

allUsedImages.forEach(({ image_name, image_origin }) => {
allUsedImages.forEach(({ image_name }) => {
dispatch(
imageUrlsReceived({
imageName: image_name,
imageOrigin: image_origin,
})
);
});

@@ -116,7 +116,6 @@ export const addUserInvokedCanvasListener = () => {
// Update the base node with the image name and type
baseNode.image = {
image_name: baseImageDTO.image_name,
image_origin: baseImageDTO.image_origin,
};
}

@@ -143,7 +142,6 @@ export const addUserInvokedCanvasListener = () => {
// Update the base node with the image name and type
baseNode.mask = {
image_name: maskImageDTO.image_name,
image_origin: maskImageDTO.image_origin,
};
}

@@ -160,7 +158,6 @@ export const addUserInvokedCanvasListener = () => {
dispatch(
imageUpdated({
imageName: baseNode.image.image_name,
imageOrigin: baseNode.image.image_origin,
requestBody: { session_id: sessionId },
})
);
@@ -171,7 +168,6 @@ export const addUserInvokedCanvasListener = () => {
dispatch(
imageUpdated({
imageName: baseNode.mask.image_name,
imageOrigin: baseNode.mask.image_origin,
requestBody: { session_id: sessionId },
})
);

@@ -1,256 +0,0 @@
import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons';
import {
Box,
Flex,
FormControl,
FormControlProps,
FormLabel,
Grid,
GridItem,
List,
ListItem,
Text,
Tooltip,
TooltipProps,
} from '@chakra-ui/react';
import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
import { useSelect } from 'downshift';
import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';

import { memo, useLayoutEffect, useMemo } from 'react';
import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';

export type ItemTooltips = { [key: string]: string };

export type IAICustomSelectOption = {
value: string;
label: string;
tooltip?: string;
};

type IAICustomSelectProps = {
label?: string;
value: string;
data: IAICustomSelectOption[] | string[];
onChange: (v: string) => void;
withCheckIcon?: boolean;
formControlProps?: FormControlProps;
tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>;
ellipsisPosition?: 'start' | 'end';
isDisabled?: boolean;
};

const IAICustomSelect = (props: IAICustomSelectProps) => {
const {
label,
withCheckIcon,
formControlProps,
tooltip,
tooltipProps,
ellipsisPosition = 'end',
data,
value,
onChange,
isDisabled = false,
} = props;

const values = useMemo(() => {
return data.map<IAICustomSelectOption>((v) => {
if (isString(v)) {
return { value: v, label: v };
}
return v;
});
}, [data]);

const stringValues = useMemo(() => {
return values.map((v) => v.value);
}, [values]);

const valueData = useMemo(() => {
return values.find((v) => v.value === value);
}, [values, value]);

const {
isOpen,
getToggleButtonProps,
getLabelProps,
getMenuProps,
highlightedIndex,
getItemProps,
} = useSelect({
items: stringValues,
selectedItem: value,
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
newSelectedItem && onChange(newSelectedItem);
},
});

const { refs, floatingStyles, update } = useFloating<HTMLButtonElement>({
// whileElementsMounted: autoUpdate,
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
});

useLayoutEffect(() => {
if (isOpen && refs.reference.current && refs.floating.current) {
return autoUpdate(refs.reference.current, refs.floating.current, update);
}
}, [isOpen, update, refs.floating, refs.reference]);

const labelTextDirection = useMemo(() => {
if (ellipsisPosition === 'start') {
return document.dir === 'rtl' ? 'ltr' : 'rtl';
}

return document.dir;
}, [ellipsisPosition]);

return (
<FormControl sx={{ w: 'full' }} {...formControlProps}>
{label && (
<FormLabel
{...getLabelProps()}
onClick={() => {
refs.floating.current && refs.floating.current.focus();
}}
>
{label}
</FormLabel>
)}
<Tooltip label={tooltip} {...tooltipProps}>
<Flex
{...getToggleButtonProps({ ref: refs.reference })}
sx={{
alignItems: 'center',
userSelect: 'none',
cursor: 'pointer',
overflow: 'hidden',
width: 'full',
py: 1,
px: 2,
gap: 2,
justifyContent: 'space-between',
pointerEvents: isDisabled ? 'none' : undefined,
opacity: isDisabled ? 0.5 : undefined,
...getInputOutlineStyles(),
}}
>
<Text
sx={{
fontSize: 'sm',
fontWeight: 500,
color: 'base.100',
whiteSpace: 'nowrap',
overflow: 'hidden',
textOverflow: 'ellipsis',
direction: labelTextDirection,
}}
>
{valueData?.label}
</Text>
<ChevronUpIcon
sx={{
color: 'base.300',
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
</Flex>
</Tooltip>
<Box {...getMenuProps()}>
{isOpen && (
<List
as={Flex}
ref={refs.floating}
sx={{
...floatingStyles,
top: 0,
insetInlineStart: 0,
flexDirection: 'column',
zIndex: 2,
bg: 'base.800',
borderRadius: 'base',
border: '1px',
borderColor: 'base.700',
shadow: 'dark-lg',
py: 2,
px: 0,
h: 'fit-content',
maxH: 64,
minW: 48,
}}
>
<OverlayScrollbarsComponent>
{values.map((v, index) => {
const isSelected = value === v.value;
const isHighlighted = highlightedIndex === index;
const fontWeight = isSelected ? 700 : 500;
const bg = isHighlighted
? 'base.700'
: isSelected
? 'base.750'
: undefined;
return (
<Tooltip
isDisabled={!v.tooltip}
key={`${v.value}${index}`}
label={v.tooltip}
hasArrow
placement="right"
>
<ListItem
sx={{
bg,
py: 1,
paddingInlineStart: 3,
paddingInlineEnd: 6,
cursor: 'pointer',
transitionProperty: 'common',
transitionDuration: '0.15s',
}}
{...getItemProps({ item: v.value, index })}
>
{withCheckIcon ? (
<Grid gridTemplateColumns="1.25rem auto">
<GridItem>
{isSelected && <CheckIcon boxSize={2} />}
</GridItem>
<GridItem>
<Text
sx={{
fontSize: 'sm',
color: 'base.100',
fontWeight,
}}
>
{v.label}
</Text>
</GridItem>
</Grid>
) : (
<Text
sx={{
fontSize: 'sm',
color: 'base.50',
fontWeight,
}}
>
{v.label}
</Text>
)}
</ListItem>
</Tooltip>
);
})}
</OverlayScrollbarsComponent>
</List>
)}
</Box>
</FormControl>
);
};

export default memo(IAICustomSelect);
@@ -0,0 +1,76 @@
import { Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { memo } from 'react';

export type IAISelectDataType = {
value: string;
label: string;
tooltip?: string;
};

type IAISelectProps = SelectProps & {
tooltip?: string;
};

const IAIMantineSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, ...rest } = props;
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Select
searchable={searchable}
styles={() => ({
label: {
color: 'var(--invokeai-colors-base-300)',
fontWeight: 'normal',
},
input: {
backgroundColor: 'var(--invokeai-colors-base-900)',
borderWidth: '2px',
borderColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)',
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },
'&:focus': {
borderColor: 'var(--invokeai-colors-accent-600)',
},
},
dropdown: {
backgroundColor: 'var(--invokeai-colors-base-800)',
borderColor: 'var(--invokeai-colors-base-700)',
},
item: {
backgroundColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-200)',
padding: 6,
'&[data-hovered]': {
color: 'var(--invokeai-colors-base-100)',
backgroundColor: 'var(--invokeai-colors-base-750)',
},
'&[data-active]': {
backgroundColor: 'var(--invokeai-colors-base-750)',
'&:hover': {
color: 'var(--invokeai-colors-base-100)',
backgroundColor: 'var(--invokeai-colors-base-750)',
},
},
'&[data-selected]': {
color: 'var(--invokeai-colors-base-50)',
backgroundColor: 'var(--invokeai-colors-accent-650)',
fontWeight: 600,
'&:hover': {
backgroundColor: 'var(--invokeai-colors-accent-600)',
},
},
},
rightSection: {
width: 24,
},
})}
{...rest}
/>
</Tooltip>
);
};

export default memo(IAIMantineSelect);
@@ -0,0 +1,12 @@
import { ScrollArea, ScrollAreaProps } from '@mantine/core';

type IAIScrollArea = ScrollAreaProps;

export default function IAIScrollArea(props: IAIScrollArea) {
const { ...rest } = props;
return (
<ScrollArea w="100%" {...rest}>
{props.children}
</ScrollArea>
);
}
@@ -2,7 +2,6 @@ import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISelect from 'common/components/IAISelect';
import useImageUploader from 'common/hooks/useImageUploader';
import { useSingleAndDoubleClick } from 'common/hooks/useSingleAndDoubleClick';
import {
@@ -25,7 +24,13 @@ import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash-es';

import { ChangeEvent } from 'react';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
canvasCopiedToClipboard,
canvasDownloadedAsImage,
canvasMerged,
canvasSavedToGallery,
} from 'features/canvas/store/actions';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import {
@@ -43,12 +48,6 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
import IAICanvasUndoButton from './IAICanvasUndoButton';
import {
canvasCopiedToClipboard,
canvasDownloadedAsImage,
canvasMerged,
canvasSavedToGallery,
} from 'features/canvas/store/actions';

export const selector = createSelector(
[systemSelector, canvasSelector, isStagingSelector],
@@ -197,8 +196,8 @@ const IAICanvasToolbar = () => {
dispatch(canvasDownloadedAsImage());
};

const handleChangeLayer = (e: ChangeEvent<HTMLSelectElement>) => {
const newLayer = e.target.value as CanvasLayer;
const handleChangeLayer = (v: string) => {
const newLayer = v as CanvasLayer;
dispatch(setLayer(newLayer));
if (newLayer === 'mask' && !isMaskEnabled) {
dispatch(setIsMaskEnabled(true));
@@ -214,13 +213,12 @@ const IAICanvasToolbar = () => {
}}
>
<Box w={24}>
<IAISelect
<IAIMantineSelect
tooltip={`${t('unifiedCanvas.layer')} (Q)`}
tooltipProps={{ hasArrow: true, placement: 'top' }}
value={layer}
validValues={LAYER_NAMES_DICT}
data={LAYER_NAMES_DICT}
onChange={handleChangeLayer}
isDisabled={isStaging}
disabled={isStaging}
/>
</Box>


@@ -866,8 +866,7 @@ export const canvasSlice = createSlice({
});

builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;

state.layerState.objects.forEach((object) => {
if (object.kind === 'image') {

@@ -4,8 +4,8 @@ import { RgbaColor } from 'react-colorful';
import { ImageDTO } from 'services/api';

export const LAYER_NAMES_DICT = [
{ key: 'Base', value: 'base' },
{ key: 'Mask', value: 'mask' },
{ label: 'Base', value: 'base' },
{ label: 'Mask', value: 'mask' },
];

export const LAYER_NAMES = ['base', 'mask'] as const;
@@ -13,9 +13,9 @@ export const LAYER_NAMES = ['base', 'mask'] as const;
export type CanvasLayer = (typeof LAYER_NAMES)[number];

export const BOUNDING_BOX_SCALES_DICT = [
{ key: 'Auto', value: 'auto' },
{ key: 'Manual', value: 'manual' },
{ key: 'None', value: 'none' },
{ label: 'Auto', value: 'auto' },
{ label: 'Manual', value: 'manual' },
{ label: 'None', value: 'none' },
];

export const BOUNDING_BOX_SCALES = ['none', 'auto', 'manual'] as const;

@@ -1,26 +1,27 @@
import { Box, ChakraProps, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import {
ControlNetConfig,
controlNetAdded,
controlNetRemoved,
controlNetToggled,
} from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { Flex, Box, ChakraProps } from '@chakra-ui/react';
import { FaCopy, FaTrash } from 'react-icons/fa';

import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ControlNetImagePreview from './ControlNetImagePreview';
import IAIIconButton from 'common/components/IAIIconButton';
import { v4 as uuidv4 } from 'uuid';
import { useToggle } from 'react-use';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import IAISwitch from 'common/components/IAISwitch';
import { ChevronUpIcon } from '@chakra-ui/icons';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { useToggle } from 'react-use';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';

const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };

@@ -36,6 +37,7 @@ const ControlNet = (props: ControlNetProps) => {
weight,
beginStepPct,
endStepPct,
controlMode,
controlImage,
processedControlImage,
processorNode,
@@ -137,45 +139,51 @@ const ControlNet = (props: ControlNetProps) => {
</Flex>
{isEnabled && (
<>
<Flex sx={{ gap: 4, w: 'full' }}>
<Flex
sx={{
flexDir: 'column',
gap: 2,
w: 'full',
h: isExpanded ? 28 : 24,
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
</Flex>
{!isExpanded && (
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex sx={{ gap: 4, w: 'full' }}>
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 24,
w: 24,
aspectRatio: '1/1',
flexDir: 'column',
gap: 3,
w: 'full',
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}}
>
<ControlNetImagePreview controlNet={props.controlNet} />
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
</Flex>
)}
{!isExpanded && (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 24,
w: 24,
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview controlNet={props.controlNet} />
</Flex>
)}
</Flex>
<ParamControlNetControlMode
controlNetId={controlNetId}
controlMode={controlMode}
/>
</Flex>

{isExpanded && (
<>
<Box mt={2}>

@@ -0,0 +1,45 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlModes,
controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';

type ParamControlNetControlModeProps = {
controlNetId: string;
controlMode: string;
};

const CONTROL_MODE_DATA = [
{ label: 'Balanced', value: 'balanced' },
{ label: 'Prompt', value: 'more_prompt' },
{ label: 'Control', value: 'more_control' },
{ label: 'Mega Control', value: 'unbalanced' },
];

export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps
) {
const { controlNetId, controlMode = false } = props;
const dispatch = useAppDispatch();

const { t } = useTranslation();

const handleControlModeChange = useCallback(
(controlMode: ControlModes) => {
dispatch(controlNetControlModeChanged({ controlNetId, controlMode }));
},
[controlNetId, dispatch]
);

return (
<IAIMantineSelect
label={t('parameters.controlNetControlMode')}
data={CONTROL_MODE_DATA}
value={String(controlMode)}
onChange={handleControlModeChange}
/>
);
}
@@ -1,9 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import {
CONTROLNET_MODELS,
@@ -12,7 +11,7 @@ import {
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es';
import { ChangeEvent, memo, useCallback } from 'react';
import { memo, useCallback } from 'react';

type ParamControlNetModelProps = {
controlNetId: string;
@@ -20,17 +19,18 @@ type ParamControlNetModelProps = {
};

const selector = createSelector(configSelector, (config) => {
return map(CONTROLNET_MODELS, (m) => ({
key: m.label,
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({
label: m.label,
value: m.type,
})).filter((d) => !config.sd.disabledControlNetModels.includes(d.value));
});
})).filter(
(d) =>
!config.sd.disabledControlNetModels.includes(
d.value as ControlNetModelName
)
);

// const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
// value: m.type,
// label: m.label,
// tooltip: m.type,
// }));
return controlNetModels;
});

const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props;
@@ -39,47 +39,23 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const isReady = useIsReadyToInvoke();

const handleModelChanged = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
(val: string | null) => {
// TODO: do not cast
const model = e.target.value as ControlNetModelName;
const model = val as ControlNetModelName;
dispatch(controlNetModelChanged({ controlNetId, model }));
},
[controlNetId, dispatch]
);

// const handleModelChanged = useCallback(
// (val: string | null | undefined) => {
// // TODO: do not cast
// const model = val as ControlNetModelName;
// dispatch(controlNetModelChanged({ controlNetId, model }));
// },
// [controlNetId, dispatch]
// );

return (
<IAISelect
tooltip={model}
tooltipProps={{ placement: 'top', hasArrow: true }}
validValues={controlNetModels}
<IAIMantineSelect
data={controlNetModels}
value={model}
onChange={handleModelChanged}
isDisabled={!isReady}
// ellipsisPosition="start"
// withCheckIcon
disabled={!isReady}
tooltip={model}
/>
);
// return (
// <IAICustomSelect
// tooltip={model}
// tooltipProps={{ placement: 'top', hasArrow: true }}
// data={DATA}
// value={model}
// onChange={handleModelChanged}
// isDisabled={!isReady}
// ellipsisPosition="start"
// withCheckIcon
// />
// );
};

export default memo(ParamControlNetModel);

@@ -1,64 +1,55 @@
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import { ChangeEvent, memo, useCallback } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';

import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import {
ControlNetProcessorNode,
ControlNetProcessorType,
} from '../../store/types';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { map } from 'lodash-es';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import IAISelect from 'common/components/IAISelect';
import { createSelector } from '@reduxjs/toolkit';
import { configSelector } from 'features/system/store/configSelectors';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';

type ParamControlNetProcessorSelectProps = {
controlNetId: string;
processorNode: ControlNetProcessorNode;
};

const CONTROLNET_PROCESSOR_TYPES = map(CONTROLNET_PROCESSORS, (p) => ({
value: p.type,
key: p.label,
})).sort((a, b) =>
// sort 'none' to the top
a.value === 'none' ? -1 : b.value === 'none' ? 1 : a.key.localeCompare(b.key)
);

const selector = createSelector(configSelector, (config) => {
return map(CONTROLNET_PROCESSORS, (p) => ({
value: p.type,
key: p.label,
}))
.sort((a, b) =>
// sort 'none' to the top
a.value === 'none'
? -1
: b.value === 'none'
? 1
: a.key.localeCompare(b.key)
const selector = createSelector(
configSelector,
(config) => {
const controlNetProcessors: IAISelectDataType[] = map(
CONTROLNET_PROCESSORS,
(p) => ({
value: p.type,
label: p.label,
})
)
.filter((d) => !config.sd.disabledControlNetProcessors.includes(d.value));
});
.sort((a, b) =>
// sort 'none' to the top
a.value === 'none'
? -1
: b.value === 'none'
? 1
: a.label.localeCompare(b.label)
)
.filter(
(d) =>
!config.sd.disabledControlNetProcessors.includes(
d.value as ControlNetProcessorType
)
);

// const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
// CONTROLNET_PROCESSORS,
// (p) => ({
// value: p.type,
// label: p.label,
// tooltip: p.description,
// })
// ).sort((a, b) =>
// // sort 'none' to the top
// a.value === 'none'
// ? -1
// : b.value === 'none'
// ? 1
// : a.label.localeCompare(b.label)
// );
return controlNetProcessors;
},
defaultSelectorOptions
);

const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
@@ -69,47 +60,26 @@ const ParamControlNetProcessorSelect = (
const controlNetProcessors = useAppSelector(selector);

const handleProcessorTypeChanged = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
(v: string | null) => {
dispatch(
controlNetProcessorTypeChanged({
controlNetId,
processorType: e.target.value as ControlNetProcessorType,
processorType: v as ControlNetProcessorType,
})
);
},
[controlNetId, dispatch]
);
// const handleProcessorTypeChanged = useCallback(
// (v: string | null | undefined) => {
// dispatch(
// controlNetProcessorTypeChanged({
// controlNetId,
// processorType: v as ControlNetProcessorType,
// })
// );
// },
// [controlNetId, dispatch]
// );

return (
<IAISelect
<IAIMantineSelect
label="Processor"
value={processorNode.type ?? 'canny_image_processor'}
validValues={controlNetProcessors}
data={controlNetProcessors}
onChange={handleProcessorTypeChanged}
isDisabled={!isReady}
disabled={!isReady}
/>
);
// return (
// <IAICustomSelect
// label="Processor"
// value={processorNode.type ?? 'canny_image_processor'}
// data={CONTROLNET_PROCESSOR_TYPES}
// onChange={handleProcessorTypeChanged}
// withCheckIcon
// isDisabled={!isReady}
// />
// );
};

export default memo(ParamControlNetProcessorSelect);

@@ -1,6 +1,5 @@
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';

@@ -23,7 +22,7 @@ type ControlNetProcessorsDict = Record<
*
* TODO: Generate from the OpenAPI schema
*/
export const CONTROLNET_PROCESSORS = {
export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
none: {
type: 'none',
label: 'none',
@@ -174,6 +173,8 @@ export const CONTROLNET_PROCESSORS = {
},
};

type ControlNetModelsDict = Record<string, ControlNetModel>;

type ControlNetModel = {
type: string;
label: string;
@@ -181,7 +182,7 @@ type ControlNetModel = {
defaultProcessor?: ControlNetProcessorType;
};

export const CONTROLNET_MODELS = {
export const CONTROLNET_MODELS: ControlNetModelsDict = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
@@ -190,6 +191,7 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd',
@@ -209,6 +211,7 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segmentation',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',
@@ -223,6 +226,7 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge',
@@ -242,10 +246,12 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)',
defaultProcessor: 'none',
},
'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)',
defaultProcessor: 'none',
},
'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace',

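Editor's sketch (not part of the diff): the constants hunks above give the previously untyped CONTROLNET_PROCESSORS and CONTROLNET_MODELS objects explicit dict types, which the slice later indexes with keyof casts. A condensed, self-contained version of the pattern; ControlNetProcessorType is a stand-in here (the real union lives in './types').

    // Stand-in for the project's processor-type union.
    type ControlNetProcessorType = 'none' | 'canny_image_processor';

    type ControlNetModel = {
      type: string;
      label: string;
      defaultProcessor?: ControlNetProcessorType;
    };

    type ControlNetModelsDict = Record<string, ControlNetModel>;

    // With the dict annotated as Record<string, ControlNetModel>, every lookup yields
    // a ControlNetModel under default compiler settings, so defaultProcessor can be
    // read directly (as the slice does in the hunks further down).
    const CONTROLNET_MODELS: ControlNetModelsDict = {
      'lllyasviel/control_v11p_sd15_canny': {
        type: 'lllyasviel/control_v11p_sd15_canny',
        label: 'Canny',
      },
      'lllyasviel/control_v11p_sd15_inpaint': {
        type: 'lllyasviel/control_v11p_sd15_inpaint',
        label: 'Inpaint',
        defaultProcessor: 'none',
      },
    };

    const inpaintDefault =
      CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_inpaint'].defaultProcessor; // 'none'
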
@@ -1,36 +1,27 @@
import { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { forEach } from 'lodash-es';
import { ImageDTO } from 'services/api';
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';
import { appSocketInvocationError } from 'services/events/actions';
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
import { isAnySessionRejected } from 'services/thunks/session';
import { controlNetImageProcessed } from './actions';
import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
import { forEach } from 'lodash-es';
import { isAnySessionRejected } from 'services/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';

export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};
export type ControlModes =
| 'balanced'
| 'more_prompt'
| 'more_control'
| 'unbalanced';

export type ControlNetConfig = {
controlNetId: string;
@@ -39,6 +30,7 @@ export type ControlNetConfig = {
weight: number;
beginStepPct: number;
endStepPct: number;
controlMode: ControlModes;
controlImage: ImageDTO | null;
processedControlImage: ImageDTO | null;
processorType: ControlNetProcessorType;
@@ -46,6 +38,21 @@ export type ControlNetConfig = {
shouldAutoConfig: boolean;
};

export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
controlMode: 'balanced',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};

export type ControlNetState = {
controlNets: Record<string, ControlNetConfig>;
isEnabled: boolean;
@@ -147,11 +154,13 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null;

if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
const processorType =
CONTROLNET_MODELS[model as keyof typeof CONTROLNET_MODELS]
.defaultProcessor;
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType
processorType as keyof typeof CONTROLNET_PROCESSORS
].default as RequiredControlNetProcessorNode;
} else {
state.controlNets[controlNetId].processorType = 'none';
@@ -181,6 +190,13 @@ export const controlNetSlice = createSlice({
const { controlNetId, endStepPct } = action.payload;
state.controlNets[controlNetId].endStepPct = endStepPct;
},
controlNetControlModeChanged: (
state,
action: PayloadAction<{ controlNetId: string; controlMode: ControlModes }>
) => {
const { controlNetId, controlMode } = action.payload;
state.controlNets[controlNetId].controlMode = controlMode;
},
controlNetProcessorParamsChanged: (
state,
action: PayloadAction<{
@@ -210,7 +226,7 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null;
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType
processorType as keyof typeof CONTROLNET_PROCESSORS
].default as RequiredControlNetProcessorNode;
state.controlNets[controlNetId].shouldAutoConfig = false;
},
@@ -227,12 +243,14 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) {
// manage the processor for the user
const processorType =
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor;
CONTROLNET_MODELS[
state.controlNets[controlNetId]
.model as keyof typeof CONTROLNET_MODELS
].defaultProcessor;
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType
processorType as keyof typeof CONTROLNET_PROCESSORS
].default as RequiredControlNetProcessorNode;
} else {
state.controlNets[controlNetId].processorType = 'none';
@@ -271,8 +289,7 @@ export const controlNetSlice = createSlice({
});

builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;

forEach(state.controlNets, (c) => {
if (c.controlImage?.image_name === image_name) {
@@ -286,11 +303,11 @@ export const controlNetSlice = createSlice({
});
});

builder.addCase(appSocketInvocationError, (state, action) => {
builder.addCase(appSocketInvocationError, (state) => {
state.pendingControlImages = [];
});

builder.addMatcher(isAnySessionRejected, (state, action) => {
builder.addMatcher(isAnySessionRejected, (state) => {
state.pendingControlImages = [];
});
},
@@ -308,6 +325,7 @@ export const {
controlNetWeightChanged,
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
controlNetControlModeChanged,
controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
controlNetReset,

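Editor's sketch (not part of the diff): the slice above adds a controlMode field (default 'balanced') and a controlNetControlModeChanged reducer, exported with the other actions. A minimal component dispatching it; the import path and the component itself are assumptions for illustration, while the action, payload shape, and ControlModes union come from the hunks above.

    import { memo } from 'react';
    import { useAppDispatch } from 'app/store/storeHooks';
    import {
      ControlModes,
      controlNetControlModeChanged,
    } from 'features/controlNet/store/controlNetSlice'; // path assumed

    type Props = { controlNetId: string };

    const ControlModeExample = ({ controlNetId }: Props) => {
      const dispatch = useAppDispatch();

      // Value-based handler, matching the IAIMantineSelect onChange contract.
      const handleControlModeChanged = (v: string | null) => {
        if (!v) {
          return;
        }
        dispatch(
          controlNetControlModeChanged({
            controlNetId,
            controlMode: v as ControlModes, // 'balanced' | 'more_prompt' | 'more_control' | 'unbalanced'
          })
        );
      };

      // Wire handleControlModeChanged to a select whose data is the four modes above.
      return null;
    };

    export default memo(ControlModeExample);
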
@@ -9,15 +9,15 @@ import {
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { ImageDTO } from 'services/api';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';

type MetadataItemProps = {
isLink?: boolean;
@@ -324,7 +324,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
borderRadius: 'base',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
w: 'max-content',
w: 'full',
}}
>
<pre>{metadataJSON}</pre>

@@ -59,8 +59,7 @@ export const gallerySlice = createSlice({
}
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;

if (state.selectedImage?.image_name === image_name) {
state.selectedImage.image_url = image_url;

@@ -86,8 +86,7 @@ const imagesSlice = createSlice({
imagesAdapter.removeOne(state, imageName);
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;

imagesAdapter.updateOne(state, {
id: image_name,

@@ -103,8 +103,7 @@ const nodesSlice = createSlice({
});

builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;

state.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => {

@@ -45,6 +45,7 @@ export const addControlNetToLinearGraph = (
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
model,
processorType,
weight,
@@ -60,23 +61,22 @@ export const addControlNetToLinearGraph = (
type: 'controlnet',
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_mode: controlMode,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};

if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
const { image_name, image_origin } = processedControlImage;
const { image_name } = processedControlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else if (controlImage) {
// The control image is preprocessed
const { image_name, image_origin } = controlImage;
const { image_name } = controlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly

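Editor's sketch (not part of the diff): throughout this branch, graph nodes now reference images by image_name alone; image_origin is no longer copied from the DTO. A tiny helper, assuming only the ImageDTO type exported from services/api, summarizing the pattern repeated in the graph-builder hunks above and below.

    import { ImageDTO } from 'services/api';

    // Previously this mapping would also have carried image_origin; now only the name is sent.
    const toImageField = (image: ImageDTO): { image_name: string } => ({
      image_name: image.image_name,
    });
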
@@ -354,7 +354,6 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
type: 'img_resize',
image: {
image_name: initialImage.image_name,
image_origin: initialImage.image_origin,
},
is_intermediate: true,
height,
@@ -392,7 +391,6 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
image_name: initialImage.image_name,
image_origin: initialImage.image_origin,
});

// Pass the image's dimensions to the `NOISE` node

@@ -57,8 +57,7 @@ export const buildImg2ImgNode = (
}

imageToImageNode.image = {
image_name: initialImage.name,
image_origin: initialImage.type,
image_name: initialImage.image_name,
};
}

@@ -1,12 +1,12 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setInfillMethod } from 'features/parameters/store/generationSlice';
import { systemSelector } from 'features/system/store/systemSelectors';

import { ChangeEvent, memo, useCallback } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

const selector = createSelector(
@@ -30,17 +30,17 @@ const ParamInfillMethod = () => {
const { t } = useTranslation();

const handleChange = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
dispatch(setInfillMethod(e.target.value));
(v: string) => {
dispatch(setInfillMethod(v));
},
[dispatch]
);

return (
<IAISelect
<IAIMantineSelect
label={t('parameters.infillMethod')}
value={infillMethod}
validValues={infillMethods}
data={infillMethods}
onChange={handleChange}
/>
);

@@ -1,15 +1,15 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { setBoundingBoxScaleMethod } from 'features/canvas/store/canvasSlice';
import {
BoundingBoxScale,
BOUNDING_BOX_SCALES_DICT,
BoundingBoxScale,
} from 'features/canvas/store/canvasTypes';

import { ChangeEvent, memo } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';

const selector = createSelector(
@@ -30,16 +30,14 @@ const ParamScaleBeforeProcessing = () => {

const { t } = useTranslation();

const handleChangeBoundingBoxScaleMethod = (
e: ChangeEvent<HTMLSelectElement>
) => {
dispatch(setBoundingBoxScaleMethod(e.target.value as BoundingBoxScale));
const handleChangeBoundingBoxScaleMethod = (v: string) => {
dispatch(setBoundingBoxScaleMethod(v as BoundingBoxScale));
};

return (
<IAISelect
<IAIMantineSelect
label={t('parameters.scaleBeforeProcessing')}
validValues={BOUNDING_BOX_SCALES_DICT}
data={BOUNDING_BOX_SCALES_DICT}
value={boundingBoxScale}
onChange={handleChangeBoundingBoxScaleMethod}
/>

@@ -2,23 +2,20 @@ import { createSelector } from '@reduxjs/toolkit';
import { Scheduler } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICustomSelect from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

const selector = createSelector(
[uiSelector, generationSelector],
(ui, generation) => {
// TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413
// but we need to wait for the next release before removing this special handling.
const allSchedulers = ui.schedulers
.filter((scheduler) => {
return !['dpmpp_2s'].includes(scheduler);
})
const allSchedulers: string[] = ui.schedulers
.slice()
.sort((a, b) => a.localeCompare(b));

return {
@@ -36,39 +33,23 @@ const ParamScheduler = () => {
const { t } = useTranslation();

const handleChange = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
dispatch(setScheduler(e.target.value as Scheduler));
(v: string | null) => {
if (!v) {
return;
}
dispatch(setScheduler(v as Scheduler));
},
[dispatch]
);
// const handleChange = useCallback(
// (v: string | null | undefined) => {
// if (!v) {
// return;
// }
// dispatch(setScheduler(v as Scheduler));
// },
// [dispatch]
// );

return (
<IAISelect
<IAIMantineSelect
label={t('parameters.scheduler')}
value={scheduler}
validValues={allSchedulers}
data={allSchedulers}
onChange={handleChange}
/>
);

// return (
// <IAICustomSelect
// label={t('parameters.scheduler')}
// value={scheduler}
// data={allSchedulers}
// onChange={handleChange}
// withCheckIcon
// />
// );
};

export default memo(ParamScheduler);

@@ -1,12 +1,11 @@
import { FACETOOL_TYPES } from 'app/constants';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
FacetoolType,
setFacetoolType,
} from 'features/parameters/store/postprocessingSlice';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';

export default function FaceRestoreType() {
@@ -17,13 +16,13 @@ export default function FaceRestoreType() {
const dispatch = useAppDispatch();
const { t } = useTranslation();

const handleChangeFacetoolType = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setFacetoolType(e.target.value as FacetoolType));
const handleChangeFacetoolType = (v: string) =>
dispatch(setFacetoolType(v as FacetoolType));

return (
<IAISelect
<IAIMantineSelect
label={t('parameters.type')}
validValues={FACETOOL_TYPES.concat()}
data={FACETOOL_TYPES.concat()}
value={facetoolType}
onChange={handleChangeFacetoolType}
/>

@@ -1,12 +1,11 @@
import { UPSCALING_LEVELS } from 'app/constants';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
setUpscalingLevel,
UpscalingLevel,
setUpscalingLevel,
} from 'features/parameters/store/postprocessingSlice';
import type { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';

export default function UpscaleScale() {
@@ -21,16 +20,16 @@ export default function UpscaleScale() {
const { t } = useTranslation();
const dispatch = useAppDispatch();

const handleChangeLevel = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setUpscalingLevel(Number(e.target.value) as UpscalingLevel));
const handleChangeLevel = (v: string) =>
dispatch(setUpscalingLevel(Number(v) as UpscalingLevel));

return (
<IAISelect
isDisabled={!isESRGANAvailable}
<IAIMantineSelect
disabled={!isESRGANAvailable}
label={t('parameters.scale')}
value={upscalingLevel}
value={String(upscalingLevel)}
onChange={handleChangeLevel}
validValues={UPSCALING_LEVELS}
data={UPSCALING_LEVELS}
/>
);
}

@@ -1,11 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import { isObject } from 'lodash-es';
import { ImageDTO, ResourceOrigin } from 'services/api';

export type ImageNameAndOrigin = {
image_name: string;
image_origin: ResourceOrigin;
};
import { ImageDTO } from 'services/api';

export const initialImageSelected = createAction<ImageDTO | string | undefined>(
'generation/initialImageSelected'

@@ -234,8 +234,7 @@ export const generationSlice = createSlice({
});

builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;

if (state.initialImage?.image_name === image_name) {
state.initialImage.image_url = image_url;

@@ -1,17 +1,16 @@
import { createSelector } from '@reduxjs/toolkit';
import { ChangeEvent, memo, useCallback } from 'react';
import { isEqual } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
import { RootState } from 'app/store/store';
import { modelSelected } from 'features/parameters/store/generationSlice';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
import { modelSelected } from 'features/parameters/store/generationSlice';
import { selectModelsAll, selectModelsById } from '../store/modelSlice';

const selector = createSelector(
[(state: RootState) => state, generationSelector],
@@ -19,18 +18,11 @@ const selector = createSelector(
const selectedModel = selectModelsById(state, generation.model);

const modelData = selectModelsAll(state)
.map((m) => ({
.map<IAISelectDataType>((m) => ({
value: m.name,
key: m.name,
label: m.name,
}))
.sort((a, b) => a.key.localeCompare(b.key));
// const modelData = selectModelsAll(state)
// .map<IAICustomSelectOption>((m) => ({
// value: m.name,
// label: m.name,
// tooltip: m.description,
// }))
// .sort((a, b) => a.label.localeCompare(b.label));
.sort((a, b) => a.label.localeCompare(b.label));
return {
selectedModel,
modelData,
@@ -48,43 +40,25 @@ const ModelSelect = () => {
const { t } = useTranslation();
const { selectedModel, modelData } = useAppSelector(selector);
const handleChangeModel = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
dispatch(modelSelected(e.target.value));
(v: string | null) => {
if (!v) {
return;
}
dispatch(modelSelected(v));
},
[dispatch]
);
// const handleChangeModel = useCallback(
// (v: string | null | undefined) => {
// if (!v) {
// return;
// }
// dispatch(modelSelected(v));
// },
// [dispatch]
// );

return (
<IAISelect
label={t('modelManager.model')}
<IAIMantineSelect
tooltip={selectedModel?.description}
validValues={modelData}
label={t('modelManager.model')}
value={selectedModel?.name ?? ''}
placeholder="Pick one"
data={modelData}
onChange={handleChangeModel}
tooltipProps={{ placement: 'top', hasArrow: true }}
/>
);

// return (
// <IAICustomSelect
// label={t('modelManager.model')}
// tooltip={selectedModel?.description}
// data={modelData}
// value={selectedModel?.name ?? ''}
// onChange={handleChangeModel}
// withCheckIcon={true}
// tooltipProps={{ placement: 'top', hasArrow: true }}
// />
// );
};

export default memo(ModelSelect);

@@ -13,19 +13,21 @@ import {
useDisclosure,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { VALID_LOG_LEVELS } from 'app/logging/useLogger';
import { LOCALSTORAGE_KEYS, LOCALSTORAGE_PREFIX } from 'app/store/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAISwitch from 'common/components/IAISwitch';
import { systemSelector } from 'features/system/store/systemSelectors';
import {
SystemState,
consoleLogLevelChanged,
setEnableImageDebugging,
setShouldConfirmOnDelete,
setShouldDisplayGuides,
shouldAntialiasProgressImageChanged,
shouldLogToConsoleChanged,
SystemState,
} from 'features/system/store/systemSlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import {
@@ -37,15 +39,13 @@ import { UIState } from 'features/ui/store/uiTypes';
import { isEqual } from 'lodash-es';
import {
ChangeEvent,
cloneElement,
ReactElement,
cloneElement,
useCallback,
useEffect,
} from 'react';
import { useTranslation } from 'react-i18next';
import { VALID_LOG_LEVELS } from 'app/logging/useLogger';
import { LogLevelName } from 'roarr';
import { LOCALSTORAGE_KEYS, LOCALSTORAGE_PREFIX } from 'app/store/constants';
import SettingsSchedulers from './SettingsSchedulers';

const selector = createSelector(
@@ -157,8 +157,8 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
}, [onSettingsModalClose, onRefreshModalOpen]);

const handleLogLevelChanged = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
dispatch(consoleLogLevelChanged(e.target.value as LogLevelName));
(v: string) => {
dispatch(consoleLogLevelChanged(v as LogLevelName));
},
[dispatch]
);
@@ -255,14 +255,12 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
isChecked={shouldLogToConsole}
onChange={handleLogToConsoleChanged}
/>
<IAISelect
horizontal
spaceEvenly
isDisabled={!shouldLogToConsole}
<IAIMantineSelect
disabled={!shouldLogToConsole}
label={t('settings.consoleLogLevel')}
onChange={handleLogLevelChanged}
value={consoleLogLevel}
validValues={VALID_LOG_LEVELS.concat()}
data={VALID_LOG_LEVELS.concat()}
/>
<IAISwitch
label={t('settings.enableImageDebugging')}

@@ -1,5 +1,4 @@
import {
Box,
Menu,
MenuButton,
MenuItemOption,
@@ -13,7 +12,6 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { setSchedulers } from 'features/ui/store/uiSlice';
import { isArray } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { useTranslation } from 'react-i18next';

export default function SettingsSchedulers() {

@@ -1,21 +1,21 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { memo, useMemo } from 'react';
import { Box, Flex } from '@chakra-ui/react';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import InvokeAILogoComponent from 'features/system/components/InvokeAILogoComponent';
import OverlayScrollable from './common/OverlayScrollable';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
import {
activeTabNameSelector,
uiSelector,
} from 'features/ui/store/uiSelectors';
import { setShouldShowParametersPanel } from 'features/ui/store/uiSlice';
import ResizableDrawer from './common/ResizableDrawer/ResizableDrawer';
import { memo, useMemo } from 'react';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
import PinParametersPanelButton from './PinParametersPanelButton';
import TextToImageTabParameters from './tabs/TextToImage/TextToImageTabParameters';
import OverlayScrollable from './common/OverlayScrollable';
import ResizableDrawer from './common/ResizableDrawer/ResizableDrawer';
import ImageToImageTabParameters from './tabs/ImageToImage/ImageToImageTabParameters';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import TextToImageTabParameters from './tabs/TextToImage/TextToImageTabParameters';
import UnifiedCanvasParameters from './tabs/UnifiedCanvas/UnifiedCanvasParameters';

const selector = createSelector(

@@ -1,11 +1,11 @@
import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { PropsWithChildren, memo } from 'react';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
import OverlayScrollable from './common/OverlayScrollable';
import PinParametersPanelButton from './PinParametersPanelButton';
import { createSelector } from '@reduxjs/toolkit';
import { uiSelector } from '../store/uiSelectors';
import { useAppSelector } from 'app/store/storeHooks';
import PinParametersPanelButton from './PinParametersPanelButton';
import OverlayScrollable from './common/OverlayScrollable';

const selector = createSelector(uiSelector, (ui) => {
const { shouldPinParametersPanel, shouldShowParametersPanel } = ui;
@@ -35,19 +35,27 @@ const ParametersPinnedWrapper = (props: ParametersPinnedWrapperProps) => {
flexShrink: 0,
}}
>
<OverlayScrollable>
<Flex
sx={{
gap: 2,
flexDirection: 'column',
h: 'full',
w: 'full',
position: 'absolute',
}}
>
{props.children}
</Flex>
</OverlayScrollable>
<Flex
sx={{
gap: 2,
flexDirection: 'column',
h: 'full',
w: 'full',
position: 'absolute',
}}
>
<OverlayScrollable>
<Flex
sx={{
flexDirection: 'column',
gap: 2,
}}
>
{props.children}
</Flex>
</OverlayScrollable>
</Flex>

<PinParametersPanelButton
sx={{ position: 'absolute', top: 0, insetInlineEnd: 0 }}
/>

@@ -1,6 +1,5 @@
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { PropsWithChildren, memo } from 'react';

const OverlayScrollable = (props: PropsWithChildren) => {
return (
<OverlayScrollbarsComponent
@@ -20,5 +19,4 @@ const OverlayScrollable = (props: PropsWithChildren) => {
</OverlayScrollbarsComponent>
);
};

export default memo(OverlayScrollable);

@@ -1,14 +1,17 @@
import { Box, Flex } from '@chakra-ui/react';
import { memo, useCallback, useRef } from 'react';
import { Panel, PanelGroup } from 'react-resizable-panels';
import { useAppDispatch } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import ResizeHandle from '../ResizeHandle';
import ImageToImageTabParameters from './ImageToImageTabParameters';
import TextToImageTabMain from '../TextToImage/TextToImageTabMain';
import { ImperativePanelGroupHandle } from 'react-resizable-panels';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
import InitialImageDisplay from 'features/parameters/components/Parameters/ImageToImage/InitialImageDisplay';
import { memo, useCallback, useRef } from 'react';
import {
ImperativePanelGroupHandle,
Panel,
PanelGroup,
} from 'react-resizable-panels';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
import ResizeHandle from '../ResizeHandle';
import TextToImageTabMain from '../TextToImage/TextToImageTabMain';
import ImageToImageTabParameters from './ImageToImageTabParameters';

const ImageToImageTab = () => {
const dispatch = useAppDispatch();

@@ -1,8 +1,8 @@
import { Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
import TextToImageTabMain from './TextToImageTabMain';
import TextToImageTabParameters from './TextToImageTabParameters';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';

const TextToImageTab = () => {
return (

@@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
canvasSelector,
isStagingSelector,
@@ -12,7 +12,6 @@ import {
} from 'features/canvas/store/canvasTypes';
import { isEqual } from 'lodash-es';

import { ChangeEvent } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';

@@ -51,22 +50,22 @@ export default function UnifiedCanvasLayerSelect() {
[layer]
);

const handleChangeLayer = (e: ChangeEvent<HTMLSelectElement>) => {
const newLayer = e.target.value as CanvasLayer;
const handleChangeLayer = (v: string) => {
const newLayer = v as CanvasLayer;
dispatch(setLayer(newLayer));
if (newLayer === 'mask' && !isMaskEnabled) {
dispatch(setIsMaskEnabled(true));
}
};
return (
<IAISelect
<IAIMantineSelect
tooltip={`${t('unifiedCanvas.layer')} (Q)`}
aria-label={`${t('unifiedCanvas.layer')} (Q)`}
tooltipProps={{ hasArrow: true, placement: 'top' }}
value={layer}
validValues={LAYER_NAMES_DICT}
data={LAYER_NAMES_DICT}
onChange={handleChangeLayer}
isDisabled={isStaging}
disabled={isStaging}
w="full"
/>
);
}

@@ -2,6 +2,7 @@ import { Flex } from '@chakra-ui/react';

import IAICanvasRedoButton from 'features/canvas/components/IAICanvasToolbar/IAICanvasRedoButton';
import IAICanvasUndoButton from 'features/canvas/components/IAICanvasToolbar/IAICanvasUndoButton';
import UnifiedCanvasSettings from './UnifiedCanvasToolSettings/UnifiedCanvasSettings';
import UnifiedCanvasCopyToClipboard from './UnifiedCanvasToolbar/UnifiedCanvasCopyToClipboard';
import UnifiedCanvasDownloadImage from './UnifiedCanvasToolbar/UnifiedCanvasDownloadImage';
import UnifiedCanvasFileUploader from './UnifiedCanvasToolbar/UnifiedCanvasFileUploader';
@@ -13,11 +14,10 @@ import UnifiedCanvasResetCanvas from './UnifiedCanvasToolbar/UnifiedCanvasResetC
import UnifiedCanvasResetView from './UnifiedCanvasToolbar/UnifiedCanvasResetView';
import UnifiedCanvasSaveToGallery from './UnifiedCanvasToolbar/UnifiedCanvasSaveToGallery';
import UnifiedCanvasToolSelect from './UnifiedCanvasToolbar/UnifiedCanvasToolSelect';
import UnifiedCanvasSettings from './UnifiedCanvasToolSettings/UnifiedCanvasSettings';

const UnifiedCanvasToolbarBeta = () => {
return (
<Flex flexDirection="column" rowGap={2}>
<Flex flexDirection="column" rowGap={2} width="min-content">
<UnifiedCanvasLayerSelect />
<UnifiedCanvasToolSelect />

@@ -1,8 +1,8 @@
import { Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
import UnifiedCanvasContent from './UnifiedCanvasContent';
import UnifiedCanvasParameters from './UnifiedCanvasParameters';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';

const UnifiedCanvasTab = () => {
return (

invokeai/frontend/web/src/mantine-theme/theme.ts (new file, 23 lines)
@@ -0,0 +1,23 @@
import { MantineThemeOverride } from '@mantine/core';

export const mantineTheme: MantineThemeOverride = {
colorScheme: 'dark',
fontFamily: `'InterVariable', sans-serif`,
components: {
ScrollArea: {
defaultProps: {
scrollbarSize: 10,
},
styles: {
scrollbar: {
'&:hover': {
backgroundColor: 'var(--invokeai-colors-baseAlpha-300)',
},
},
thumb: {
backgroundColor: 'var(--invokeai-colors-baseAlpha-300)',
},
},
},
},
};
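Editor's sketch (not part of the diff): theme.ts is a new file exporting a MantineThemeOverride with a dark color scheme and ScrollArea defaults. The provider wiring is not shown in this compare view; a minimal illustration of how such an override is typically applied, assuming Mantine's standard MantineProvider API.

    import { ReactNode } from 'react';
    import { MantineProvider } from '@mantine/core';
    import { mantineTheme } from 'mantine-theme/theme'; // path from the new file above

    // Wrap the app (or a subtree) so Mantine components such as Select pick up the override.
    const WithMantineTheme = ({ children }: { children: ReactNode }) => (
      <MantineProvider theme={mantineTheme}>{children}</MantineProvider>
    );

    export default WithMantineTheme;
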
@@ -24,9 +24,11 @@ export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
export type { DivideInvocation } from './models/DivideInvocation';
export type { DynamicPromptInvocation } from './models/DynamicPromptInvocation';
export type { Edge } from './models/Edge';
export type { EdgeConnection } from './models/EdgeConnection';
export type { FloatCollectionOutput } from './models/FloatCollectionOutput';
export type { FloatLinearRangeInvocation } from './models/FloatLinearRangeInvocation';
export type { FloatOutput } from './models/FloatOutput';
export type { Graph } from './models/Graph';
export type { GraphExecutionState } from './models/GraphExecutionState';
@@ -85,6 +87,7 @@ export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedRe
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation';
export type { RandomRangeInvocation } from './models/RandomRangeInvocation';
@@ -95,6 +98,7 @@ export type { ResourceOrigin } from './models/ResourceOrigin';
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
export type { ShowImageInvocation } from './models/ShowImageInvocation';
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
export type { SubtractInvocation } from './models/SubtractInvocation';
export type { TextToImageInvocation } from './models/TextToImageInvocation';
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';

@@ -24,4 +24,3 @@ export type AddInvocation = {
*/
'b'?: number;
};

@@ -5,4 +5,3 @@
export type Body_upload_image = {
file: Blob;
};

@@ -30,4 +30,3 @@ export type CannyImageProcessorInvocation = {
*/
high_threshold?: number;
};

@@ -29,4 +29,3 @@ export type CkptModelInfo = {
*/
height?: number;
};

@@ -24,4 +24,3 @@ export type CollectInvocation = {
*/
collection?: Array<any>;
};

@@ -12,4 +12,3 @@ export type CollectInvocationOutput = {
*/
collection: Array<any>;
};

@@ -20,4 +20,3 @@ export type ColorField = {
*/
'a': number;
};

@@ -24,4 +24,3 @@ export type CompelInvocation = {
*/
model?: string;
};

@@ -14,4 +14,3 @@ export type CompelOutput = {
*/
conditioning?: ConditioningField;
};

@@ -8,4 +8,3 @@ export type ConditioningField = {
*/
conditioning_name: string;
};

@@ -42,4 +42,3 @@ export type ContentShuffleImageProcessorInvocation = {
*/
'f'?: number;
};

@@ -16,7 +16,7 @@ export type ControlField = {
/**
* The weight given to the ControlNet
*/
control_weight: number;
control_weight: (number | Array<number>);
/**
* When the ControlNet is first applied (% of total steps)
*/
@@ -25,5 +25,8 @@ export type ControlField = {
* When the ControlNet is last applied (% of total steps)
*/
end_step_percent: number;
/**
* The contorl mode to use
*/
control_mode?: 'balanced' | 'more_prompt' | 'more_control' | 'unbalanced';
};

@@ -22,13 +22,13 @@ export type ControlNetInvocation = {
*/
image?: ImageField;
/**
* The ControlNet model to use
* control model used
*/
control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15' | 'CrucibleAI/ControlNetMediaPipeFace';
/**
* The weight given to the ControlNet
*/
control_weight?: number;
control_weight?: (number | Array<number>);
/**
* When the ControlNet is first applied (% of total steps)
*/
@@ -37,5 +37,8 @@ export type ControlNetInvocation = {
* When the ControlNet is last applied (% of total steps)
*/
end_step_percent?: number;
/**
* The control mode used
*/
control_mode?: 'balanced' | 'more_prompt' | 'more_control' | 'unbalanced';
};

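Editor's sketch (not part of the diff): the regenerated ControlField/ControlNetInvocation types above widen control_weight to number | Array<number> and add an optional control_mode. A payload exercising both, assuming the updated generated types exported from services/api; the id and numeric values are illustrative.

    import { ControlNetInvocation } from 'services/api';

    const node: ControlNetInvocation = {
      id: 'controlnet_1', // hypothetical node id
      type: 'controlnet',
      control_model: 'lllyasviel/control_v11p_sd15_canny',
      control_weight: [1.0, 0.75, 0.5], // an array of weights is now accepted alongside a single number
      begin_step_percent: 0,
      end_step_percent: 0.8,
      control_mode: 'more_prompt', // one of the four modes declared above
    };
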
@@ -10,8 +10,7 @@ import type { ControlField } from './ControlField';
export type ControlOutput = {
type?: 'control_output';
/**
* The output control image
* The control info
*/
control?: ControlField;
};

@@ -15,4 +15,3 @@ export type CreateModelRequest = {
*/
info: (CkptModelInfo | DiffusersModelInfo);
};

@@ -26,4 +26,3 @@ export type CvInpaintInvocation = {
*/
mask?: ImageField;
};

@@ -23,4 +23,3 @@ export type DiffusersModelInfo = {
*/
path?: string;
};

@@ -24,4 +24,3 @@ export type DivideInvocation = {
*/
'b'?: number;
};

@@ -0,0 +1,30 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

/**
* Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator
*/
export type DynamicPromptInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'dynamic_prompt';
/**
* The prompt to parse with dynamicprompts
*/
prompt: string;
/**
* The number of prompts to generate
*/
max_prompts?: number;
/**
* Whether to use the combinatorial generator
*/
combinatorial?: boolean;
};
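Editor's sketch (not part of the diff): DynamicPromptInvocation is a newly generated node type (exported from services/api in the index hunk above) that parses a prompt with adieyal/dynamicprompts. Constructing one for a graph might look like this; field names come from the type above, the id and values are illustrative.

    import { DynamicPromptInvocation } from 'services/api';

    const dynamicPromptNode: DynamicPromptInvocation = {
      id: 'dynamic_prompt_1', // hypothetical node id
      type: 'dynamic_prompt',
      prompt: 'a photo of a {red|green|blue} house', // dynamicprompts variant syntax
      max_prompts: 5,
      combinatorial: true, // enumerate combinations instead of sampling randomly
    };
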
@@ -14,4 +14,3 @@ export type Edge = {
*/
destination: EdgeConnection;
};

@@ -12,4 +12,3 @@ export type EdgeConnection = {
*/
field: string;
};

@@ -12,4 +12,3 @@ export type FloatCollectionOutput = {
*/
collection?: Array<number>;
};

@@ -0,0 +1,30 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

/**
* Creates a range
*/
export type FloatLinearRangeInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'float_range';
/**
* The first value of the range
*/
start?: number;
/**
* The last value of the range
*/
stop?: number;
/**
* number of values to interpolate over (including start and stop)
*/
steps?: number;
};

@@ -12,4 +12,3 @@ export type FloatOutput = {
*/
param?: number;
};

@@ -10,7 +10,9 @@ import type { ContentShuffleImageProcessorInvocation } from './ContentShuffleIma
import type { ControlNetInvocation } from './ControlNetInvocation';
import type { CvInpaintInvocation } from './CvInpaintInvocation';
import type { DivideInvocation } from './DivideInvocation';
import type { DynamicPromptInvocation } from './DynamicPromptInvocation';
import type { Edge } from './Edge';
import type { FloatLinearRangeInvocation } from './FloatLinearRangeInvocation';
import type { GraphInvocation } from './GraphInvocation';
import type { HedImageProcessorInvocation } from './HedImageProcessorInvocation';
import type { ImageBlurInvocation } from './ImageBlurInvocation';
@@ -55,6 +57,7 @@ import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation';
import type { RestoreFaceInvocation } from './RestoreFaceInvocation';
import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation';
import type { ShowImageInvocation } from './ShowImageInvocation';
import type { StepParamEasingInvocation } from './StepParamEasingInvocation';
import type { SubtractInvocation } from './SubtractInvocation';
import type { TextToImageInvocation } from './TextToImageInvocation';
import type { TextToLatentsInvocation } from './TextToLatentsInvocation';
@@ -69,10 +72,9 @@ export type Graph = {
/**
* The nodes in this graph
*/
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
nodes?: Record<string, (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | CompelInvocation | LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CvInpaintInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | DynamicPromptInvocation | RestoreFaceInvocation | UpscaleInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | ImageToImageInvocation | LatentsToLatentsInvocation | InpaintInvocation)>;
/**
* The connections between nodes and their fields in this graph
*/
edges?: Array<Edge>;
};

@@ -16,6 +16,7 @@ import type { IterateInvocationOutput } from './IterateInvocationOutput';
import type { LatentsOutput } from './LatentsOutput';
import type { MaskOutput } from './MaskOutput';
import type { NoiseOutput } from './NoiseOutput';
import type { PromptCollectionOutput } from './PromptCollectionOutput';
import type { PromptOutput } from './PromptOutput';

/**
@@ -45,7 +46,7 @@ export type GraphExecutionState = {
/**
* The results of node executions
*/
results: Record<string, (ImageOutput | MaskOutput | ControlOutput | PromptOutput | CompelOutput | IntOutput | FloatOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | FloatCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
results: Record<string, (IntCollectionOutput | FloatCollectionOutput | CompelOutput | ImageOutput | MaskOutput | ControlOutput | LatentsOutput | NoiseOutput | IntOutput | FloatOutput | PromptOutput | PromptCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
/**
* Errors raised when executing nodes
*/
@@ -59,4 +60,3 @@ export type GraphExecutionState = {
*/
source_prepared_mapping: Record<string, Array<string>>;
};

@@ -22,4 +22,3 @@ export type GraphInvocation = {
*/
graph?: Graph;
};

@@ -8,4 +8,3 @@
export type GraphInvocationOutput = {
type: 'graph_output';
};

@@ -7,4 +7,3 @@ import type { ValidationError } from './ValidationError';
export type HTTPValidationError = {
detail?: Array<ValidationError>;
};

Some files were not shown because too many files have changed in this diff.