Compare commits

...

60 Commits

Author SHA1 Message Date
Sergey Borisov
e4a45341c8 Controlnet implementation for sequential execution 2023-06-16 02:42:32 +03:00
blessedcoolant
4ca325e8e6 chore: Rebuild API 2023-06-15 03:20:49 +12:00
blessedcoolant
6b8e88ad7f Merge branch 'main' into feat/controlnet-control-modes 2023-06-15 03:18:41 +12:00
psychedelicious
0497bea264 fix: add dynamicprompts to pyproject.toml 2023-06-15 01:05:16 +10:00
psychedelicious
b8e32fa459 chore(ui): regen api client 2023-06-15 01:05:16 +10:00
psychedelicious
34ebee67b7 fix(nodes): fix revert conflict 2023-06-15 01:05:16 +10:00
psychedelicious
e0c998d192 Revert "feat(ui): add warning socket event handling"
This reverts commit e7a61e631a42190e4b64e0d5e22771c669c5b30c.
2023-06-15 01:05:16 +10:00
psychedelicious
b51e9a6bdb Revert "feat(nodes): add warning socket event"
This reverts commit cefdd9d634e515239bd85666c872a0d64bb9d772.
2023-06-15 01:05:16 +10:00
psychedelicious
09f396ce84 feat(ui): add warning socket event handling 2023-06-15 01:05:16 +10:00
psychedelicious
abee37eab3 feat(nodes): add warning socket event 2023-06-15 01:05:16 +10:00
psychedelicious
42e48b2bef feat(nodes): add dynamic prompt node 2023-06-15 01:05:16 +10:00
blessedcoolant
70ece4364c refactor(minor): Image & Latent File Storage (#3538)
- `DiskImageStorage` and `DiskLatentsStorage` have now both been updated
to exclusively work with `Path` objects and not rely on the `os` lib to
handle pathing related functions.
- We now also validate the existence of the required image output
folders and latent output folders to ensure that the app does not break
in case the required folders get tampered with mid-session.
- Just overall general cleanup.

Tested it. Don't seem to be any thing breaking.
2023-06-15 02:43:27 +12:00
psychedelicious
f9d5f9d52c fix(nodes): minor fixes for folder validation
- fix type for `__output_folder`
- prefix `validate_storage_folders()` with `__` to indicate private method
2023-06-15 00:40:39 +10:00
blessedcoolant
587297878a refactor(minor): Latent Disk Storage 2023-06-15 02:21:49 +12:00
blessedcoolant
b4c998a9ae refactor(minor): Image File Storage 2023-06-15 01:58:58 +12:00
psychedelicious
88e8e3977b feat(ui): update UI to not use image_origin
see commit `8ad8de8: feat(nodes): remove `image_origin` from most places` for details.
2023-06-14 23:08:27 +10:00
psychedelicious
24b86cffe9 chore(ui): regen api client & types 2023-06-14 23:08:27 +10:00
psychedelicious
a1773197e9 feat(nodes): remove image_origin from most places
- remove `image_origin` from most places where we interact with images
- consolidate image file storage into a single `images/` dir

Images have an `image_origin` attribute but it is not actually used when retrieving images, nor will it ever be. It is still used when creating images and helps to differentiate between internally generated images and uploads.

It was included in eg API routes and image service methods as a holdover from the previous app implementation where images were not managed in a database. Now that we have images in a db, we can do away with this and simplify basically everything that touches images.

The one potentially controversial change is to no longer separate internal and external images on disk. If we retain this separation, we have to keep `image_origin` around in a number of spots and it getting image paths on disk painful.

So, I am have gotten rid of this organisation. Images are now all stored in `images`, regardless of their origin. As we improve the image management features, this change will hopefully become transparent.
2023-06-14 23:08:27 +10:00
blessedcoolant
6c53abc034 feat: Add ControlMode to Linear UI 2023-06-14 20:01:17 +12:00
blessedcoolant
eb7047b21d chore: Rebuild WebAPI 2023-06-14 19:26:02 +12:00
blessedcoolant
43419ac761 Merge branch 'main' into feat/controlnet-control-modes 2023-06-14 19:04:42 +12:00
user1
5cd0e90816 Renamed ControlNet control_mode option "even_more_control" to "unbalanced" 2023-06-13 22:30:17 -07:00
user1
cfd49e3921 Removing vestigial comments. 2023-06-13 21:33:15 -07:00
user1
a8e0490133 Merge branch 'feat/controlnet-control-modes' of https://github.com/invoke-ai/InvokeAI into feat/controlnet-control-modes 2023-06-13 21:21:13 -07:00
psychedelicious
1e08d865c9 chore: dummy commit to trigger actions 2023-06-14 14:14:24 +10:00
blessedcoolant
f8bb650cc1 revert: IAIScrollArea 2023-06-14 14:14:24 +10:00
psychedelicious
2cee8bebb2 fix(ui): revert offset scrollbars
The wonky padding is too janky. Just overlay for now.
2023-06-14 14:14:24 +10:00
psychedelicious
ade4ec5fd8 fix(ui): fix crash when toggling pinned parameters panel 2023-06-14 14:14:24 +10:00
psychedelicious
70ffd6b03f fix(ui): fix controlnet selects data types 2023-06-14 14:14:24 +10:00
psychedelicious
6c551df311 fix(ui): fix rebase conflicts 2023-06-14 14:14:24 +10:00
blessedcoolant
24f605629e cleanup: Remove OverlayScrollable component 2023-06-14 14:14:24 +10:00
blessedcoolant
2af1ec9d02 fix: Minor padding issue in unpinned drawer 2023-06-14 14:14:24 +10:00
blessedcoolant
79d53341de fix: Stretch scroll area so it retains parent width 2023-06-14 14:14:24 +10:00
blessedcoolant
e40b3506c4 fix: Options squishing on accordion collapse 2023-06-14 14:14:24 +10:00
blessedcoolant
33912382e3 feat: Introduce Mantine's ScrollArea 2023-06-14 14:14:24 +10:00
blessedcoolant
d282810e53 cleanup: Remove IAICustomSelect and port types 2023-06-14 14:14:24 +10:00
psychedelicious
9df502fc77 fix(ui): fix mantine select props 2023-06-14 14:14:24 +10:00
psychedelicious
705573f0a8 feat(ui): even more pedantic mantine select theming 2023-06-14 14:14:24 +10:00
blessedcoolant
1878ea94f6 feat: Port Canvas Layer Select to IAIMantineSelect 2023-06-14 14:14:24 +10:00
psychedelicious
4ba5086b9a feat(ui): add tooltip to IAIMantineSelect 2023-06-14 14:14:24 +10:00
psychedelicious
4a991b4daa feat(ui): more pedantic mantine select theming 2023-06-14 14:14:24 +10:00
psychedelicious
80474d26f9 feat(ui): mantine scrollbar theming 2023-06-14 14:14:24 +10:00
blessedcoolant
9a77bd9140 feat: Port IAISelect's to IAIMantineSelect's
Ported everything except Model Manager selects and the Canvas Layer Select (this needs tooltip support)
2023-06-14 14:14:24 +10:00
psychedelicious
14cdc800c3 feat(ui): pedantic mantine select theming 2023-06-14 14:14:24 +10:00
blessedcoolant
9cfbea4c25 feat: Match styling of Mantine Select with InvokeAI 2023-06-14 14:14:24 +10:00
blessedcoolant
5fe674e223 feat: Standardize IAIMantineSelect Component 2023-06-14 14:14:24 +10:00
blessedcoolant
32200efce8 feat: Change default font to Inter 2023-06-14 14:14:24 +10:00
blessedcoolant
68a02da990 feat: Use Mantine Select for Scheduler 2023-06-14 14:14:24 +10:00
blessedcoolant
5b20766ea3 chore: Move Mantine Theme Override to own file 2023-06-14 14:14:24 +10:00
blessedcoolant
9a914250a0 feat: Change Model Select To Mantine 2023-06-14 14:14:24 +10:00
blessedcoolant
0e3106f631 feat: Add Mantine Support 2023-06-14 14:14:24 +10:00
user1
de3e6cdb02 Switched over to ControlNet control_mode with 4 options: balanced, more_prompt, more_control, even_more_control. Based on True/False combinations of internal booleans cfg_injection and soft_injection 2023-06-13 21:08:34 -07:00
user1
8495764d45 Moving from ControlNet guess_mode to separate booleans for cfg_injection and soft_injection for testing control modes 2023-06-13 00:41:36 -07:00
user1
8b7fac75ed First pass at ControlNet "guess mode" implementation. 2023-06-13 00:41:36 -07:00
user1
9e0e26f4c4 Moving from ControlNet guess_mode to separate booleans for cfg_injection and soft_injection for testing control modes 2023-06-12 23:57:57 -07:00
blessedcoolant
46cac6468e Upgrade to Diffusers 0.17.0 (#3514)
Diffusers is due for an update soon. #3512

Opening up a PR now with the required changes for when the new version
is live.

I've tested it out on Windows and nothing has broken from what I could
tell. I'd like someone to run some tests on Linux / Mac just to make
sure. Refer to the PR above on how to test it or install the release
branch.

```
pip install diffusers[torch]==0.17.0
```

Feel free to push any other changes to this PR you see fit.
2023-06-13 07:11:02 +12:00
blessedcoolant
2a814d886b Merge branch 'main' into diffusers-upgrade 2023-06-13 05:29:15 +12:00
user1
fd715026a7 First pass at ControlNet "guess mode" implementation. 2023-06-11 02:00:39 -07:00
blessedcoolant
7bce455d16 Merge branch 'main' into diffusers-upgrade 2023-06-09 16:27:52 +12:00
blessedcoolant
68405910ba Upgrade to Diffusers 0.17.0 2023-06-08 04:42:52 +12:00
190 changed files with 1649 additions and 2047 deletions

View File

@@ -70,27 +70,25 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image")
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
@images_router.delete("/{image_name}", operation_id="delete_image")
async def delete_image(
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image"""
try:
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
ApiDependencies.invoker.services.images.delete(image_name)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
@images_router.patch(
"/{image_origin}/{image_name}",
"/{image_name}",
operation_id="update_image",
response_model=ImageDTO,
)
async def update_image(
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image"
@@ -99,32 +97,29 @@ async def update_image(
"""Updates an image"""
try:
return ApiDependencies.invoker.services.images.update(
image_origin, image_name, image_changes
)
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get(
"/{image_origin}/{image_name}/metadata",
"/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageDTO,
)
async def get_image_metadata(
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
image_name: str = Path(description="The name of image to get"),
) -> ImageDTO:
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
return ApiDependencies.invoker.services.images.get_dto(image_name)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_origin}/{image_name}",
"/{image_name}",
operation_id="get_image_full",
response_class=Response,
responses={
@@ -136,15 +131,12 @@ async def get_image_metadata(
},
)
async def get_image_full(
image_origin: ResourceOrigin = Path(
description="The type of full-resolution image file to get"
),
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> FileResponse:
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
path = ApiDependencies.invoker.services.images.get_path(image_name)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
@@ -160,7 +152,7 @@ async def get_image_full(
@images_router.get(
"/{image_origin}/{image_name}/thumbnail",
"/{image_name}/thumbnail",
operation_id="get_image_thumbnail",
response_class=Response,
responses={
@@ -172,14 +164,13 @@ async def get_image_full(
},
)
async def get_image_thumbnail(
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse:
"""Gets a thumbnail image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(
image_origin, image_name, thumbnail=True
image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
@@ -192,25 +183,21 @@ async def get_image_thumbnail(
@images_router.get(
"/{image_origin}/{image_name}/urls",
"/{image_name}/urls",
operation_id="get_image_urls",
response_model=ImageUrlsDTO,
)
async def get_image_urls(
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL"""
try:
image_url = ApiDependencies.invoker.services.images.get_url(
image_origin, image_name
)
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_origin, image_name, thumbnail=True
image_name, thumbnail=True
)
return ImageUrlsDTO(
image_origin=image_origin,
image_name=image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,

View File

@@ -1,7 +1,7 @@
# InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float
from builtins import float, bool
import numpy as np
from typing import Literal, Optional, Union, List
@@ -94,6 +94,7 @@ CONTROLNET_DEFAULT_MODELS = [
]
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
@@ -104,6 +105,8 @@ class ControlField(BaseModel):
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The contorl mode to use")
@validator("control_weight")
def abs_le_one(cls, v):
"""validate that all abs(values) are <=1"""
@@ -144,11 +147,11 @@ class ControlNetInvocation(BaseInvocation):
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
# fmt: on
class Config(InvocationConfig):
@@ -166,7 +169,6 @@ class ControlNetInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
image=self.image,
@@ -174,6 +176,7 @@ class ControlNetInvocation(BaseInvocation):
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
control_mode=self.control_mode,
),
)
@@ -193,9 +196,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
return image
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
raw_image = context.services.images.get_pil_image(self.image.image_name)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
@@ -216,10 +217,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
)
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
)
processed_image_field = ImageField(image_name=image_dto.image_name)
return ImageOutput(
image=processed_image_field,
# width=processed_image.width,

View File

@@ -36,12 +36,8 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
mask = context.services.images.get_pil_image(
self.mask.image_origin, self.mask.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
mask = context.services.images.get_pil_image(self.mask.image_name)
# Convert to cv image/mask
# TODO: consider making these utility functions
@@ -65,10 +61,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -86,9 +86,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# loading controlnet image (currently requires pre-processed image)
control_image = (
None if self.control_image is None
else context.services.images.get_pil_image(
self.control_image.image_origin, self.control_image.image_name
)
else context.services.images.get_pil_image(self.control_image.image_name)
)
# loading controlnet model
if (self.control_model is None or self.control_model==''):
@@ -128,10 +126,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -169,9 +164,7 @@ class ImageToImageInvocation(TextToImageInvocation):
image = (
None
if self.image is None
else context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
else context.services.images.get_pil_image(self.image.image_name)
)
if self.fit:
@@ -209,10 +202,7 @@ class ImageToImageInvocation(TextToImageInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -282,14 +272,12 @@ class InpaintInvocation(ImageToImageInvocation):
image = (
None
if self.image is None
else context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
else context.services.images.get_pil_image(self.image.image_name)
)
mask = (
None
if self.mask is None
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
else context.services.images.get_pil_image(self.mask.image_name)
)
# Handle invalid model parameter
@@ -325,10 +313,7 @@ class InpaintInvocation(ImageToImageInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -72,13 +72,10 @@ class LoadImageInvocation(BaseInvocation):
)
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
image = context.services.images.get_pil_image(self.image.image_name)
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@@ -95,19 +92,14 @@ class ShowImageInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
if image:
image.show()
# TODO: how to handle failure?
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@@ -128,9 +120,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_crop = Image.new(
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
@@ -147,10 +137,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -171,19 +158,13 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image(
self.base_image.image_origin, self.base_image.image_name
)
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
base_image = context.services.images.get_pil_image(self.base_image.image_name)
image = context.services.images.get_pil_image(self.image.image_name)
mask = (
None
if self.mask is None
else ImageOps.invert(
context.services.images.get_pil_image(
self.mask.image_origin, self.mask.image_name
)
context.services.images.get_pil_image(self.mask.image_name)
)
)
# TODO: probably shouldn't invert mask here... should user be required to do it?
@@ -209,10 +190,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -230,9 +208,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_mask = image.split()[-1]
if self.invert:
@@ -248,9 +224,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
)
return MaskOutput(
mask=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
mask=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -268,12 +242,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(
self.image1.image_origin, self.image1.image_name
)
image2 = context.services.images.get_pil_image(
self.image2.image_origin, self.image2.image_name
)
image1 = context.services.images.get_pil_image(self.image1.image_name)
image2 = context.services.images.get_pil_image(self.image2.image_name)
multiply_image = ImageChops.multiply(image1, image2)
@@ -287,9 +257,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -310,9 +278,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
channel_image = image.getchannel(self.channel)
@@ -326,9 +292,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -349,9 +313,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
converted_image = image.convert(self.mode)
@@ -365,9 +327,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -386,9 +346,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
blur = (
ImageFilter.GaussianBlur(self.radius)
@@ -407,10 +365,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -450,9 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@@ -471,10 +424,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -493,9 +443,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width * self.scale_factor)
@@ -516,10 +464,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -538,9 +483,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
image_arr = image_arr * (self.max - self.min) + self.max
@@ -557,10 +500,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -579,9 +519,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = (
@@ -603,10 +541,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -134,9 +134,7 @@ class InfillColorInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
@@ -153,10 +151,7 @@ class InfillColorInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -179,9 +174,7 @@ class InfillTileInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
infilled = tile_fill_missing(
image.copy(), seed=self.seed, tile_size=self.tile_size
@@ -198,10 +191,7 @@ class InfillTileInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -217,9 +207,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
if PatchMatch.patchmatch_available():
infilled = infill_patchmatch(image.copy())
@@ -236,10 +224,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -5,7 +5,7 @@ import einops
from typing import Literal, Optional, Union, List
from compel import Compel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from pydantic import BaseModel, Field, validator
import torch
@@ -282,19 +282,14 @@ class TextToLatentsInvocation(BaseInvocation):
control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8
if control_input is None:
# print("control input is None")
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
# print("control input is empty list")
control_list = None
elif isinstance(control_input, ControlField):
# print("control input is ControlField")
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
# print("control input is list[ControlField]")
control_list = control_input
else:
# print("input control is unrecognized:", type(self.control))
control_list = None
if (control_list is None):
control_data = None
@@ -321,8 +316,7 @@ class TextToLatentsInvocation(BaseInvocation):
torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
control_image_field.image_name)
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
@@ -337,12 +331,15 @@ class TextToLatentsInvocation(BaseInvocation):
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
)
control_item = ControlNetData(model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent)
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
@@ -502,10 +499,7 @@ class LatentsToImageInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -601,9 +595,7 @@ class ImageToLatentsInvocation(BaseInvocation):
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
# TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model)

View File

@@ -2,8 +2,8 @@ from typing import Literal
from pydantic.fields import Field
from .baseinvocation import BaseInvocationOutput
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
@@ -20,3 +20,38 @@ class PromptOutput(BaseInvocationOutput):
'prompt',
]
}
class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a collection of prompts"""
# fmt: off
type: Literal["prompt_collection_output"] = "prompt_collection_output"
prompt_collection: list[str] = Field(description="The output prompt collection")
count: int = Field(description="The size of the prompt collection")
# fmt: on
class Config:
schema_extra = {"required": ["type", "prompt_collection", "count"]}
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
type: Literal["dynamic_prompt"] = "dynamic_prompt"
prompt: str = Field(description="The prompt to parse with dynamicprompts")
max_prompts: int = Field(default=1, description="The number of prompts to generate")
combinatorial: bool = Field(
default=False, description="Whether to use the combinatorial generator"
)
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
if self.combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
else:
generator = RandomPromptGenerator()
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))

View File

@@ -28,9 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=None,
@@ -51,10 +49,7 @@ class RestoreFaceInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -30,9 +30,7 @@ class UpscaleInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
@@ -53,10 +51,7 @@ class UpscaleInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -66,13 +66,10 @@ class InvalidImageCategoryException(ValueError):
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_origin: ResourceOrigin = Field(
default=ResourceOrigin.INTERNAL, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_origin", "image_name"]}
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):

View File

@@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
@@ -40,14 +39,12 @@ class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get(self, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the internal path to an image or thumbnail."""
pass
@@ -62,7 +59,6 @@ class ImageFileStorageBase(ABC):
def save(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
@@ -71,7 +67,7 @@ class ImageFileStorageBase(ABC):
pass
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
@@ -79,31 +75,26 @@ class ImageFileStorageBase(ABC):
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: str
__output_folder: Path
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, PILImageType]
__cache: Dict[Path, PILImageType]
__max_cache_size: int
def __init__(self, output_folder: str):
self.__output_folder = output_folder
def __init__(self, output_folder: str | Path):
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True)
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / 'thumbnails'
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_origin in ResourceOrigin:
Path(os.path.join(output_folder, image_origin)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
# Validate required output folders at launch
self.__validate_storage_folders()
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get(self, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_origin, image_name)
image_path = self.get_path(image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
@@ -117,13 +108,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
def save(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
try:
image_path = self.get_path(image_origin, image_name)
self.__validate_storage_folders()
image_path = self.get_path(image_name)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
@@ -133,7 +124,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path)
@@ -142,20 +133,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e:
raise ImageFileSaveException from e
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
try:
basename = os.path.basename(image_name)
image_path = self.get_path(image_origin, basename)
image_path = self.get_path(image_name)
if os.path.exists(image_path):
if image_path.exists():
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
thumbnail_path = self.get_path(thumbnail_name, True)
if os.path.exists(thumbnail_path):
if thumbnail_path.exists():
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
@@ -163,41 +153,33 @@ class DiskImageFileStorage(ImageFileStorageBase):
raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
path = self.__output_folder / image_name
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
self.__output_folder, image_origin, "thumbnails", thumbnail_name
)
else:
path = os.path.join(self.__output_folder, image_origin, basename)
thumbnail_name = get_thumbnail_name(image_name)
path = self.__thumbnails_folder / thumbnail_name
abspath = os.path.abspath(path)
return path
return abspath
def validate_path(self, path: str) -> bool:
def validate_path(self, path: str | Path) -> bool:
"""Validates the path given for an image or thumbnail."""
try:
os.stat(path)
return True
except:
return False
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
for folder in folders:
folder.mkdir(parents=True, exist_ok=True)
def __get_cache(self, image_name: str) -> PILImageType | None:
def __get_cache(self, image_name: Path) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: PILImageType):
def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:

View File

@@ -21,6 +21,7 @@ from invokeai.app.services.models.image_record import (
T = TypeVar("T", bound=BaseModel)
class OffsetPaginatedResults(GenericModel, Generic[T]):
"""Offset-paginated results"""
@@ -60,7 +61,7 @@ class ImageRecordStorageBase(ABC):
# TODO: Implement an `update()` method
@abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
def get(self, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@@ -68,7 +69,6 @@ class ImageRecordStorageBase(ABC):
def update(
self,
image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
"""Updates an image record."""
@@ -89,7 +89,7 @@ class ImageRecordStorageBase(ABC):
# TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
"""Deletes an image record."""
pass
@@ -196,9 +196,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
def get(
self, image_origin: ResourceOrigin, image_name: str
) -> Union[ImageRecord, None]:
def get(self, image_name: str) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
@@ -225,7 +223,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def update(
self,
image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
try:
@@ -294,9 +291,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
if categories is not None:
## Convert the enum values to unique list of strings
category_strings = list(
map(lambda c: c.value, set(categories))
)
category_strings = list(map(lambda c: c.value, set(categories)))
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n"
@@ -337,7 +332,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
items=images, offset=offset, limit=limit, total=count
)
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(

View File

@@ -57,7 +57,6 @@ class ImageServiceABC(ABC):
@abstractmethod
def update(
self,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
@@ -65,22 +64,22 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get_pil_image(self, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
pass
@abstractmethod
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
def get_record(self, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
def get_dto(self, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
def get_path(self, image_name: str) -> str:
"""Gets an image's path."""
pass
@@ -90,9 +89,7 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's or thumbnail's URL."""
pass
@@ -109,7 +106,7 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str):
def delete(self, image_name: str):
"""Deletes an image."""
pass
@@ -206,16 +203,13 @@ class ImageService(ImageServiceABC):
)
self._services.files.save(
image_origin=image_origin,
image_name=image_name,
image=image,
metadata=metadata,
)
image_url = self._services.urls.get_image_url(image_origin, image_name)
thumbnail_url = self._services.urls.get_image_url(
image_origin, image_name, True
)
image_url = self._services.urls.get_image_url(image_name)
thumbnail_url = self._services.urls.get_image_url(image_name, True)
return ImageDTO(
# Non-nullable fields
@@ -249,13 +243,12 @@ class ImageService(ImageServiceABC):
def update(
self,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.records.update(image_name, image_origin, changes)
return self.get_dto(image_origin, image_name)
self._services.records.update(image_name, changes)
return self.get_dto(image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
@@ -263,9 +256,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get_pil_image(self, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_origin, image_name)
return self._services.files.get(image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
@@ -273,9 +266,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image file")
raise e
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
def get_record(self, image_name: str) -> ImageRecord:
try:
return self._services.records.get(image_origin, image_name)
return self._services.records.get(image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
@@ -283,14 +276,14 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record")
raise e
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
def get_dto(self, image_name: str) -> ImageDTO:
try:
image_record = self._services.records.get(image_origin, image_name)
image_record = self._services.records.get(image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_origin, image_name),
self._services.urls.get_image_url(image_origin, image_name, True),
self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_name, True),
)
return image_dto
@@ -301,11 +294,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image DTO")
raise e
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.files.get_path(image_origin, image_name, thumbnail)
return self._services.files.get_path(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
@@ -317,11 +308,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem validating image path")
raise e
def get_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
return self._services.urls.get_image_url(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
@@ -347,10 +336,8 @@ class ImageService(ImageServiceABC):
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_origin, r.image_name),
self._services.urls.get_image_url(
r.image_origin, r.image_name, True
),
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
),
results.items,
)
@@ -366,10 +353,10 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting paginated image DTOs")
raise e
def delete(self, image_origin: ResourceOrigin, image_name: str):
def delete(self, image_name: str):
try:
self._services.files.delete(image_origin, image_name)
self._services.records.delete(image_origin, image_name)
self._services.files.delete(image_name)
self._services.records.delete(image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise

View File

@@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
@@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: str
__output_folder: str | Path
def __init__(self, output_folder: str):
self.__output_folder = output_folder
Path(output_folder).mkdir(parents=True, exist_ok=True)
def __init__(self, output_folder: str | Path):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
def save(self, name: str, data: torch.Tensor) -> None:
self.__output_folder.mkdir(parents=True, exist_ok=True)
latent_path = self.get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
os.remove(latent_path)
latent_path.unlink()
def get_path(self, name: str) -> str:
return os.path.join(self.__output_folder, name)
def get_path(self, name: str) -> Path:
return self.__output_folder / name

View File

@@ -79,8 +79,6 @@ class ImageUrlsDTO(BaseModel):
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The origin of the image."""
image_url: str = Field(description="The URL of the image.")
"""The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")

View File

@@ -1,17 +1,12 @@
import os
from abc import ABC, abstractmethod
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.util.thumbnails import get_thumbnail_name
class UrlServiceBase(ABC):
"""Responsible for building URLs for resources."""
@abstractmethod
def get_image_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the URL for an image or thumbnail."""
pass
@@ -20,15 +15,11 @@ class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url
def get_image_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return (
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
)
return f"{self._base_url}/images/{image_basename}/thumbnail"
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
return f"{self._base_url}/images/{image_basename}"

View File

@@ -6,7 +6,7 @@ import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from ..stable_diffusion import (
ConditioningData,

View File

@@ -23,7 +23,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline,
@@ -46,6 +46,7 @@ from .diffusion import (
AttentionMapSaver,
InvokeAIDiffuserComponent,
PostprocessingSettings,
ControlNetData,
)
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
from .textual_inversion_manager import TextualInversionManager
@@ -214,13 +215,6 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
raise AssertionError("why was that an empty generator?")
return result
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None)
weight: Union[float, List[float]]= Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
@dataclass(frozen=True)
class ConditioningData:
@@ -656,69 +650,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
# default is no controlnet, so set controlnet processing output to None
down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None:
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
# if conditioning_data.guidance_scale > 1.0:
if conditioning_data.guidance_scale is not None:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
latent_control_input = torch.cat([latent_model_input] * 2)
else:
latent_control_input = latent_model_input
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
# print("controlnet", i, "==>", type(control_datum))
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index]
else:
# if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight
down_samples, mid_sample = control_datum.model(
sample=latent_control_input,
timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]),
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight,
# cross_attention_kwargs,
guess_mode=False,
return_dict=False,
)
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
# add controlnet outputs together if have multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
# predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step(
latent_model_input,
t,
conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings,
conditioning_data.guidance_scale,
x=unet_latent_input,
sigma=t,
unconditioning=conditioning_data.unconditioned_embeddings,
conditioning=conditioning_data.text_embeddings,
unconditional_guidance_scale=conditioning_data.guidance_scale,
control_data=control_data,
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
)
# compute the previous noisy sample x_t -> x_t-1
@@ -1038,6 +981,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced"
):
if not isinstance(image, torch.Tensor):
@@ -1068,6 +1012,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
image = torch.cat([image] * 2)
#cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
#if do_classifier_free_guidance and not cfg_injection:
# image = torch.cat([image] * 2)
return image

View File

@@ -3,4 +3,4 @@ Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin
from .cross_attention_map_saving import AttentionMapSaver
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings, ControlNetData

View File

@@ -1,11 +1,14 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pydantic import Field
from math import ceil
from typing import Any, Callable, Dict, Optional, Union, List
import numpy as np
import torch
import math
from diffusers import UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias
@@ -40,6 +43,17 @@ class PostprocessingSettings:
v_symmetry_time_pct: Optional[float]
# TODO: pydantic Field work with dataclasses?
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor = Field(default=None)
weight: Union[float, List[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced")
class InvokeAIDiffuserComponent:
"""
The aim of this component is to provide a single place for code that can be applied identically to
@@ -182,8 +196,9 @@ class InvokeAIDiffuserComponent:
conditioning: Union[torch.Tensor, dict],
# unconditional_guidance_scale: float,
unconditional_guidance_scale: Union[float, List[float]],
step_index: Optional[int] = None,
total_step_count: Optional[int] = None,
step_index: int,
total_step_count: int,
control_data: Optional[List[ControlNetData]],
**kwargs,
):
"""
@@ -213,31 +228,56 @@ class InvokeAIDiffuserComponent:
)
)
if self.sequential_guidance:
down_block_res_samples, mid_block_res_sample = self._run_controlnet_sequentially(
unconditioning=unconditioning,
conditioning=conditioning,
control_data=control_data,
sample=x,
timestep=sigma,
step_index=step_index,
total_step_count=total_step_count,
)
else:
down_block_res_samples, mid_block_res_sample = self._run_controlnet_normally(
unconditioning=unconditioning,
conditioning=conditioning,
control_data=control_data,
sample=x,
timestep=sigma,
step_index=step_index,
total_step_count=total_step_count,
)
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
wants_hybrid_conditioning = isinstance(conditioning, dict)
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
x, sigma, unconditioning, conditioning, **kwargs,
x, sigma, unconditioning, conditioning,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
**kwargs,
)
elif wants_cross_attention_control:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
x,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
**kwargs,
)
elif self.sequential_guidance:
elif True: #self.sequential_guidance:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
x, sigma, unconditioning, conditioning, **kwargs,
x, sigma, unconditioning, conditioning,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
**kwargs,
)
else:
@@ -245,7 +285,10 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
x, sigma, unconditioning, conditioning, **kwargs,
x, sigma, unconditioning, conditioning,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
**kwargs,
)
combined_next_x = self._combine(
@@ -293,16 +336,160 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
def _run_controlnet_normally(
self,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
control_data: List[ControlNetData],
sample: torch.Tensor,
timestep: torch.Tensor,
step_index: int,
total_step_count: int,
):
if control_data is None:
return (None, None)
down_block_res_samples, mid_block_res_sample = None, None
for i, control_datum in enumerate(control_data):
control_mode = control_datum.control_mode
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
if cfg_injection:
control_sample = sample
control_timestep = timestep
control_image_tensor = control_datum.image_tensor
encoder_hidden_states = conditioning # TODO: ask bug
else:
control_sample = torch.cat([sample] * 2)
control_timestep = torch.cat([timestep] * 2)
control_image_tensor = torch.cat([control_datum.image_tensor] * 2)
encoder_hidden_states = torch.cat([unconditioning, conditioning])
if isinstance(control_datum.weight, list):
weight = control_datum.weight[step_index]
else:
weight = control_datum.weight
# controlnet(s) inference
down_samples, mid_sample = control_datum.model(
sample=control_sample,
timestep=control_timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_image_tensor,
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
if cfg_injection:
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
return down_block_res_samples, mid_block_res_sample
def _run_controlnet_sequentially(
self,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
control_data: List[ControlNetData],
sample: torch.Tensor,
timestep: torch.Tensor,
step_index: int,
total_step_count: int,
):
if control_data is None:
return (None, None)
down_block_res_samples, mid_block_res_sample = None, None
for i, control_datum in enumerate(control_data):
control_mode = control_datum.control_mode
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
if isinstance(control_datum.weight, list):
weight = control_datum.weight[step_index]
else:
weight = control_datum.weight
# controlnet(s) inference
cond_down_samples, cond_mid_sample = control_datum.model(
sample=sample,
timestep=timestep,
encoder_hidden_states=conditioning, # TODO: ask bug
controlnet_cond=control_datum.image_tensor,
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
if cfg_injection:
uncond_down_samples = [torch.zeros_like(d) for d in cond_down_samples]
uncond_mid_sample = torch.zeros_like(cond_mid_sample)
else:
uncond_down_samples, uncond_mid_sample = control_datum.model(
sample=sample,
timestep=timestep,
encoder_hidden_states=unconditioning,
controlnet_cond=control_datum.image_tensor,
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
down_samples = [torch.cat([ud, cd]) for ud, cd in zip(uncond_down_samples, cond_down_samples)]
mid_sample = torch.cat([uncond_mid_sample, cond_mid_sample])
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
return down_block_res_samples, mid_block_res_sample
def _apply_standard_conditioning(
self,
x: torch.Tensor,
sigma: torch.Tensor,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
**kwargs
):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning])
both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, **kwargs,
)
both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings, **kwargs)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
if conditioned_next_x.device.type == "mps":
# TODO: check if this still present
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x
@@ -310,15 +497,43 @@ class InvokeAIDiffuserComponent:
def _apply_standard_conditioning_sequentially(
self,
x: torch.Tensor,
sigma,
sigma: torch.Tensor,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
down_block_additional_residuals, # from controlnet(s)
mid_block_additional_residual, # from controlnet(s)
**kwargs,
):
# split controlnet data to cond and uncond
if down_block_additional_residuals is None:
uncond_down_block_res_samples = None
cond_down_block_res_samples = None
uncond_mid_block_res_sample = None
cond_mid_block_res_sample = None
else:
uncond_down_block_res_samples = []
cond_down_block_res_samples = []
for d in down_block_additional_residuals:
ud, cd = d.chunk(2)
uncond_down_block_res_samples.append(ud)
cond_down_block_res_samples.append(cd)
uncond_mid_block_res_sample, cond_mid_block_res_sample = mid_block_additional_residual.chunk(2)
# low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
unconditioned_next_x = self.model_forward_callback(
x, sigma, unconditioning, **kwargs,
down_block_additional_residuals=uncond_down_block_res_samples,
mid_block_additional_residual=uncond_mid_block_res_sample,
)
conditioned_next_x = self.model_forward_callback(
x, sigma, conditioning, **kwargs,
down_block_additional_residuals=cond_down_block_res_samples,
mid_block_additional_residual=cond_mid_block_res_sample,
)
if conditioned_next_x.device.type == "mps":
# TODO: check if still present
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x

View File

@@ -62,10 +62,12 @@
"@dagrejs/graphlib": "^2.1.12",
"@dnd-kit/core": "^6.0.8",
"@dnd-kit/modifiers": "^6.0.1",
"@emotion/react": "^11.10.6",
"@emotion/react": "^11.11.1",
"@emotion/styled": "^11.10.6",
"@floating-ui/react-dom": "^2.0.0",
"@fontsource/inter": "^4.5.15",
"@mantine/core": "^6.0.13",
"@mantine/hooks": "^6.0.13",
"@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5",
"chakra-ui-contextmenu": "^1.0.5",

View File

@@ -524,7 +524,8 @@
"initialImage": "Initial Image",
"showOptionsPanel": "Show Options Panel",
"hidePreview": "Hide Preview",
"showPreview": "Show Preview"
"showPreview": "Show Preview",
"controlNetControlMode": "Control Mode"
},
"settings": {
"models": "Models",

View File

@@ -3,11 +3,11 @@ import {
createLocalStorageManager,
extendTheme,
} from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { ReactNode, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { theme as invokeAITheme } from 'theme/theme';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { greenTeaThemeColors } from 'theme/colors/greenTea';
import { invokeAIThemeColors } from 'theme/colors/invokeAI';
@@ -15,6 +15,8 @@ import { lightThemeColors } from 'theme/colors/lightTheme';
import { oceanBlueColors } from 'theme/colors/oceanBlue';
import '@fontsource/inter/variable.css';
import { MantineProvider } from '@mantine/core';
import { mantineTheme } from 'mantine-theme/theme';
import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css';
@@ -51,9 +53,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
}, [direction]);
return (
<ChakraProvider theme={theme} colorModeManager={manager}>
{children}
</ChakraProvider>
<MantineProvider withGlobalStyles theme={mantineTheme}>
<ChakraProvider theme={theme} colorModeManager={manager}>
{children}
</ChakraProvider>
</MantineProvider>
);
}

View File

@@ -22,9 +22,9 @@ export const SCHEDULERS = [
export type Scheduler = (typeof SCHEDULERS)[number];
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
{ key: '2x', value: 2 },
{ key: '4x', value: 4 },
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [
{ label: '2x', value: '2' },
{ label: '4x', value: '4' },
];
export const NUMPY_RAND_MIN = 0;

View File

@@ -34,10 +34,7 @@ export const addControlNetImageProcessedListener = () => {
[controlNet.processorNode.id]: {
...controlNet.processorNode,
is_intermediate: true,
image: pick(controlNet.controlImage, [
'image_name',
'image_origin',
]),
image: pick(controlNet.controlImage, ['image_name']),
},
},
};

View File

@@ -25,7 +25,7 @@ export const addRequestedImageDeletionListener = () => {
effect: (action, { dispatch, getState }) => {
const { image, imageUsage } = action.payload;
const { image_name, image_origin } = image;
const { image_name } = image;
const state = getState();
const selectedImage = state.gallery.selectedImage;
@@ -79,9 +79,7 @@ export const addRequestedImageDeletionListener = () => {
dispatch(imageRemoved(image_name));
// Delete from server
dispatch(
imageDeleted({ imageName: image_name, imageOrigin: image_origin })
);
dispatch(imageDeleted({ imageName: image_name }));
},
});
};

View File

@@ -20,7 +20,6 @@ export const addImageMetadataReceivedFulfilledListener = () => {
dispatch(
imageUpdated({
imageName: image.image_name,
imageOrigin: image.image_origin,
requestBody: { is_intermediate: false },
})
);

View File

@@ -36,13 +36,12 @@ export const addInvocationCompleteEventListener = () => {
// This complete event has an associated image output
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const { image_name, image_origin } = result.image;
const { image_name } = result.image;
// Get its metadata
dispatch(
imageMetadataReceived({
imageName: image_name,
imageOrigin: image_origin,
})
);

View File

@@ -11,12 +11,11 @@ export const addStagingAreaImageSavedListener = () => {
startAppListening({
actionCreator: stagingAreaImageSaved,
effect: async (action, { dispatch, getState, take }) => {
const { image_name, image_origin } = action.payload;
const { image_name } = action.payload;
dispatch(
imageUpdated({
imageName: image_name,
imageOrigin: image_origin,
requestBody: {
is_intermediate: false,
},

View File

@@ -80,11 +80,10 @@ export const addUpdateImageUrlsOnConnectListener = () => {
`Fetching new image URLs for ${allUsedImages.length} images`
);
allUsedImages.forEach(({ image_name, image_origin }) => {
allUsedImages.forEach(({ image_name }) => {
dispatch(
imageUrlsReceived({
imageName: image_name,
imageOrigin: image_origin,
})
);
});

View File

@@ -116,7 +116,6 @@ export const addUserInvokedCanvasListener = () => {
// Update the base node with the image name and type
baseNode.image = {
image_name: baseImageDTO.image_name,
image_origin: baseImageDTO.image_origin,
};
}
@@ -143,7 +142,6 @@ export const addUserInvokedCanvasListener = () => {
// Update the base node with the image name and type
baseNode.mask = {
image_name: maskImageDTO.image_name,
image_origin: maskImageDTO.image_origin,
};
}
@@ -160,7 +158,6 @@ export const addUserInvokedCanvasListener = () => {
dispatch(
imageUpdated({
imageName: baseNode.image.image_name,
imageOrigin: baseNode.image.image_origin,
requestBody: { session_id: sessionId },
})
);
@@ -171,7 +168,6 @@ export const addUserInvokedCanvasListener = () => {
dispatch(
imageUpdated({
imageName: baseNode.mask.image_name,
imageOrigin: baseNode.mask.image_origin,
requestBody: { session_id: sessionId },
})
);

View File

@@ -1,256 +0,0 @@
import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons';
import {
Box,
Flex,
FormControl,
FormControlProps,
FormLabel,
Grid,
GridItem,
List,
ListItem,
Text,
Tooltip,
TooltipProps,
} from '@chakra-ui/react';
import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
import { useSelect } from 'downshift';
import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useLayoutEffect, useMemo } from 'react';
import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
export type ItemTooltips = { [key: string]: string };
export type IAICustomSelectOption = {
value: string;
label: string;
tooltip?: string;
};
type IAICustomSelectProps = {
label?: string;
value: string;
data: IAICustomSelectOption[] | string[];
onChange: (v: string) => void;
withCheckIcon?: boolean;
formControlProps?: FormControlProps;
tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>;
ellipsisPosition?: 'start' | 'end';
isDisabled?: boolean;
};
const IAICustomSelect = (props: IAICustomSelectProps) => {
const {
label,
withCheckIcon,
formControlProps,
tooltip,
tooltipProps,
ellipsisPosition = 'end',
data,
value,
onChange,
isDisabled = false,
} = props;
const values = useMemo(() => {
return data.map<IAICustomSelectOption>((v) => {
if (isString(v)) {
return { value: v, label: v };
}
return v;
});
}, [data]);
const stringValues = useMemo(() => {
return values.map((v) => v.value);
}, [values]);
const valueData = useMemo(() => {
return values.find((v) => v.value === value);
}, [values, value]);
const {
isOpen,
getToggleButtonProps,
getLabelProps,
getMenuProps,
highlightedIndex,
getItemProps,
} = useSelect({
items: stringValues,
selectedItem: value,
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
newSelectedItem && onChange(newSelectedItem);
},
});
const { refs, floatingStyles, update } = useFloating<HTMLButtonElement>({
// whileElementsMounted: autoUpdate,
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
});
useLayoutEffect(() => {
if (isOpen && refs.reference.current && refs.floating.current) {
return autoUpdate(refs.reference.current, refs.floating.current, update);
}
}, [isOpen, update, refs.floating, refs.reference]);
const labelTextDirection = useMemo(() => {
if (ellipsisPosition === 'start') {
return document.dir === 'rtl' ? 'ltr' : 'rtl';
}
return document.dir;
}, [ellipsisPosition]);
return (
<FormControl sx={{ w: 'full' }} {...formControlProps}>
{label && (
<FormLabel
{...getLabelProps()}
onClick={() => {
refs.floating.current && refs.floating.current.focus();
}}
>
{label}
</FormLabel>
)}
<Tooltip label={tooltip} {...tooltipProps}>
<Flex
{...getToggleButtonProps({ ref: refs.reference })}
sx={{
alignItems: 'center',
userSelect: 'none',
cursor: 'pointer',
overflow: 'hidden',
width: 'full',
py: 1,
px: 2,
gap: 2,
justifyContent: 'space-between',
pointerEvents: isDisabled ? 'none' : undefined,
opacity: isDisabled ? 0.5 : undefined,
...getInputOutlineStyles(),
}}
>
<Text
sx={{
fontSize: 'sm',
fontWeight: 500,
color: 'base.100',
whiteSpace: 'nowrap',
overflow: 'hidden',
textOverflow: 'ellipsis',
direction: labelTextDirection,
}}
>
{valueData?.label}
</Text>
<ChevronUpIcon
sx={{
color: 'base.300',
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
</Flex>
</Tooltip>
<Box {...getMenuProps()}>
{isOpen && (
<List
as={Flex}
ref={refs.floating}
sx={{
...floatingStyles,
top: 0,
insetInlineStart: 0,
flexDirection: 'column',
zIndex: 2,
bg: 'base.800',
borderRadius: 'base',
border: '1px',
borderColor: 'base.700',
shadow: 'dark-lg',
py: 2,
px: 0,
h: 'fit-content',
maxH: 64,
minW: 48,
}}
>
<OverlayScrollbarsComponent>
{values.map((v, index) => {
const isSelected = value === v.value;
const isHighlighted = highlightedIndex === index;
const fontWeight = isSelected ? 700 : 500;
const bg = isHighlighted
? 'base.700'
: isSelected
? 'base.750'
: undefined;
return (
<Tooltip
isDisabled={!v.tooltip}
key={`${v.value}${index}`}
label={v.tooltip}
hasArrow
placement="right"
>
<ListItem
sx={{
bg,
py: 1,
paddingInlineStart: 3,
paddingInlineEnd: 6,
cursor: 'pointer',
transitionProperty: 'common',
transitionDuration: '0.15s',
}}
{...getItemProps({ item: v.value, index })}
>
{withCheckIcon ? (
<Grid gridTemplateColumns="1.25rem auto">
<GridItem>
{isSelected && <CheckIcon boxSize={2} />}
</GridItem>
<GridItem>
<Text
sx={{
fontSize: 'sm',
color: 'base.100',
fontWeight,
}}
>
{v.label}
</Text>
</GridItem>
</Grid>
) : (
<Text
sx={{
fontSize: 'sm',
color: 'base.50',
fontWeight,
}}
>
{v.label}
</Text>
)}
</ListItem>
</Tooltip>
);
})}
</OverlayScrollbarsComponent>
</List>
)}
</Box>
</FormControl>
);
};
export default memo(IAICustomSelect);

View File

@@ -0,0 +1,76 @@
import { Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { memo } from 'react';
export type IAISelectDataType = {
value: string;
label: string;
tooltip?: string;
};
type IAISelectProps = SelectProps & {
tooltip?: string;
};
const IAIMantineSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, ...rest } = props;
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Select
searchable={searchable}
styles={() => ({
label: {
color: 'var(--invokeai-colors-base-300)',
fontWeight: 'normal',
},
input: {
backgroundColor: 'var(--invokeai-colors-base-900)',
borderWidth: '2px',
borderColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)',
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },
'&:focus': {
borderColor: 'var(--invokeai-colors-accent-600)',
},
},
dropdown: {
backgroundColor: 'var(--invokeai-colors-base-800)',
borderColor: 'var(--invokeai-colors-base-700)',
},
item: {
backgroundColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-200)',
padding: 6,
'&[data-hovered]': {
color: 'var(--invokeai-colors-base-100)',
backgroundColor: 'var(--invokeai-colors-base-750)',
},
'&[data-active]': {
backgroundColor: 'var(--invokeai-colors-base-750)',
'&:hover': {
color: 'var(--invokeai-colors-base-100)',
backgroundColor: 'var(--invokeai-colors-base-750)',
},
},
'&[data-selected]': {
color: 'var(--invokeai-colors-base-50)',
backgroundColor: 'var(--invokeai-colors-accent-650)',
fontWeight: 600,
'&:hover': {
backgroundColor: 'var(--invokeai-colors-accent-600)',
},
},
},
rightSection: {
width: 24,
},
})}
{...rest}
/>
</Tooltip>
);
};
export default memo(IAIMantineSelect);

View File

@@ -0,0 +1,12 @@
import { ScrollArea, ScrollAreaProps } from '@mantine/core';
type IAIScrollArea = ScrollAreaProps;
export default function IAIScrollArea(props: IAIScrollArea) {
const { ...rest } = props;
return (
<ScrollArea w="100%" {...rest}>
{props.children}
</ScrollArea>
);
}

View File

@@ -2,7 +2,6 @@ import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISelect from 'common/components/IAISelect';
import useImageUploader from 'common/hooks/useImageUploader';
import { useSingleAndDoubleClick } from 'common/hooks/useSingleAndDoubleClick';
import {
@@ -25,7 +24,13 @@ import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash-es';
import { ChangeEvent } from 'react';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
canvasCopiedToClipboard,
canvasDownloadedAsImage,
canvasMerged,
canvasSavedToGallery,
} from 'features/canvas/store/actions';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import {
@@ -43,12 +48,6 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
import IAICanvasUndoButton from './IAICanvasUndoButton';
import {
canvasCopiedToClipboard,
canvasDownloadedAsImage,
canvasMerged,
canvasSavedToGallery,
} from 'features/canvas/store/actions';
export const selector = createSelector(
[systemSelector, canvasSelector, isStagingSelector],
@@ -197,8 +196,8 @@ const IAICanvasToolbar = () => {
dispatch(canvasDownloadedAsImage());
};
const handleChangeLayer = (e: ChangeEvent<HTMLSelectElement>) => {
const newLayer = e.target.value as CanvasLayer;
const handleChangeLayer = (v: string) => {
const newLayer = v as CanvasLayer;
dispatch(setLayer(newLayer));
if (newLayer === 'mask' && !isMaskEnabled) {
dispatch(setIsMaskEnabled(true));
@@ -214,13 +213,12 @@ const IAICanvasToolbar = () => {
}}
>
<Box w={24}>
<IAISelect
<IAIMantineSelect
tooltip={`${t('unifiedCanvas.layer')} (Q)`}
tooltipProps={{ hasArrow: true, placement: 'top' }}
value={layer}
validValues={LAYER_NAMES_DICT}
data={LAYER_NAMES_DICT}
onChange={handleChangeLayer}
isDisabled={isStaging}
disabled={isStaging}
/>
</Box>

View File

@@ -866,8 +866,7 @@ export const canvasSlice = createSlice({
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;
state.layerState.objects.forEach((object) => {
if (object.kind === 'image') {

View File

@@ -4,8 +4,8 @@ import { RgbaColor } from 'react-colorful';
import { ImageDTO } from 'services/api';
export const LAYER_NAMES_DICT = [
{ key: 'Base', value: 'base' },
{ key: 'Mask', value: 'mask' },
{ label: 'Base', value: 'base' },
{ label: 'Mask', value: 'mask' },
];
export const LAYER_NAMES = ['base', 'mask'] as const;
@@ -13,9 +13,9 @@ export const LAYER_NAMES = ['base', 'mask'] as const;
export type CanvasLayer = (typeof LAYER_NAMES)[number];
export const BOUNDING_BOX_SCALES_DICT = [
{ key: 'Auto', value: 'auto' },
{ key: 'Manual', value: 'manual' },
{ key: 'None', value: 'none' },
{ label: 'Auto', value: 'auto' },
{ label: 'Manual', value: 'manual' },
{ label: 'None', value: 'none' },
];
export const BOUNDING_BOX_SCALES = ['none', 'auto', 'manual'] as const;

View File

@@ -1,26 +1,27 @@
import { Box, ChakraProps, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import {
ControlNetConfig,
controlNetAdded,
controlNetRemoved,
controlNetToggled,
} from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { Flex, Box, ChakraProps } from '@chakra-ui/react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ControlNetImagePreview from './ControlNetImagePreview';
import IAIIconButton from 'common/components/IAIIconButton';
import { v4 as uuidv4 } from 'uuid';
import { useToggle } from 'react-use';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import IAISwitch from 'common/components/IAISwitch';
import { ChevronUpIcon } from '@chakra-ui/icons';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { useToggle } from 'react-use';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
@@ -36,6 +37,7 @@ const ControlNet = (props: ControlNetProps) => {
weight,
beginStepPct,
endStepPct,
controlMode,
controlImage,
processedControlImage,
processorNode,
@@ -137,45 +139,51 @@ const ControlNet = (props: ControlNetProps) => {
</Flex>
{isEnabled && (
<>
<Flex sx={{ gap: 4, w: 'full' }}>
<Flex
sx={{
flexDir: 'column',
gap: 2,
w: 'full',
h: isExpanded ? 28 : 24,
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
</Flex>
{!isExpanded && (
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex sx={{ gap: 4, w: 'full' }}>
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 24,
w: 24,
aspectRatio: '1/1',
flexDir: 'column',
gap: 3,
w: 'full',
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}}
>
<ControlNetImagePreview controlNet={props.controlNet} />
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
</Flex>
)}
{!isExpanded && (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 24,
w: 24,
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview controlNet={props.controlNet} />
</Flex>
)}
</Flex>
<ParamControlNetControlMode
controlNetId={controlNetId}
controlMode={controlMode}
/>
</Flex>
{isExpanded && (
<>
<Box mt={2}>

View File

@@ -0,0 +1,45 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlModes,
controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetControlModeProps = {
controlNetId: string;
controlMode: string;
};
const CONTROL_MODE_DATA = [
{ label: 'Balanced', value: 'balanced' },
{ label: 'Prompt', value: 'more_prompt' },
{ label: 'Control', value: 'more_control' },
{ label: 'Mega Control', value: 'unbalanced' },
];
export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps
) {
const { controlNetId, controlMode = false } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleControlModeChange = useCallback(
(controlMode: ControlModes) => {
dispatch(controlNetControlModeChanged({ controlNetId, controlMode }));
},
[controlNetId, dispatch]
);
return (
<IAIMantineSelect
label={t('parameters.controlNetControlMode')}
data={CONTROL_MODE_DATA}
value={String(controlMode)}
onChange={handleControlModeChange}
/>
);
}

View File

@@ -1,9 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import {
CONTROLNET_MODELS,
@@ -12,7 +11,7 @@ import {
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es';
import { ChangeEvent, memo, useCallback } from 'react';
import { memo, useCallback } from 'react';
type ParamControlNetModelProps = {
controlNetId: string;
@@ -20,17 +19,18 @@ type ParamControlNetModelProps = {
};
const selector = createSelector(configSelector, (config) => {
return map(CONTROLNET_MODELS, (m) => ({
key: m.label,
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({
label: m.label,
value: m.type,
})).filter((d) => !config.sd.disabledControlNetModels.includes(d.value));
});
})).filter(
(d) =>
!config.sd.disabledControlNetModels.includes(
d.value as ControlNetModelName
)
);
// const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
// value: m.type,
// label: m.label,
// tooltip: m.type,
// }));
return controlNetModels;
});
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props;
@@ -39,47 +39,23 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const isReady = useIsReadyToInvoke();
const handleModelChanged = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
(val: string | null) => {
// TODO: do not cast
const model = e.target.value as ControlNetModelName;
const model = val as ControlNetModelName;
dispatch(controlNetModelChanged({ controlNetId, model }));
},
[controlNetId, dispatch]
);
// const handleModelChanged = useCallback(
// (val: string | null | undefined) => {
// // TODO: do not cast
// const model = val as ControlNetModelName;
// dispatch(controlNetModelChanged({ controlNetId, model }));
// },
// [controlNetId, dispatch]
// );
return (
<IAISelect
tooltip={model}
tooltipProps={{ placement: 'top', hasArrow: true }}
validValues={controlNetModels}
<IAIMantineSelect
data={controlNetModels}
value={model}
onChange={handleModelChanged}
isDisabled={!isReady}
// ellipsisPosition="start"
// withCheckIcon
disabled={!isReady}
tooltip={model}
/>
);
// return (
// <IAICustomSelect
// tooltip={model}
// tooltipProps={{ placement: 'top', hasArrow: true }}
// data={DATA}
// value={model}
// onChange={handleModelChanged}
// isDisabled={!isReady}
// ellipsisPosition="start"
// withCheckIcon
// />
// );
};
export default memo(ParamControlNetModel);

View File

@@ -1,64 +1,55 @@
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import { ChangeEvent, memo, useCallback } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import {
ControlNetProcessorNode,
ControlNetProcessorType,
} from '../../store/types';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { map } from 'lodash-es';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import IAISelect from 'common/components/IAISelect';
import { createSelector } from '@reduxjs/toolkit';
import { configSelector } from 'features/system/store/configSelectors';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
type ParamControlNetProcessorSelectProps = {
controlNetId: string;
processorNode: ControlNetProcessorNode;
};
const CONTROLNET_PROCESSOR_TYPES = map(CONTROLNET_PROCESSORS, (p) => ({
value: p.type,
key: p.label,
})).sort((a, b) =>
// sort 'none' to the top
a.value === 'none' ? -1 : b.value === 'none' ? 1 : a.key.localeCompare(b.key)
);
const selector = createSelector(configSelector, (config) => {
return map(CONTROLNET_PROCESSORS, (p) => ({
value: p.type,
key: p.label,
}))
.sort((a, b) =>
// sort 'none' to the top
a.value === 'none'
? -1
: b.value === 'none'
? 1
: a.key.localeCompare(b.key)
const selector = createSelector(
configSelector,
(config) => {
const controlNetProcessors: IAISelectDataType[] = map(
CONTROLNET_PROCESSORS,
(p) => ({
value: p.type,
label: p.label,
})
)
.filter((d) => !config.sd.disabledControlNetProcessors.includes(d.value));
});
.sort((a, b) =>
// sort 'none' to the top
a.value === 'none'
? -1
: b.value === 'none'
? 1
: a.label.localeCompare(b.label)
)
.filter(
(d) =>
!config.sd.disabledControlNetProcessors.includes(
d.value as ControlNetProcessorType
)
);
// const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
// CONTROLNET_PROCESSORS,
// (p) => ({
// value: p.type,
// label: p.label,
// tooltip: p.description,
// })
// ).sort((a, b) =>
// // sort 'none' to the top
// a.value === 'none'
// ? -1
// : b.value === 'none'
// ? 1
// : a.label.localeCompare(b.label)
// );
return controlNetProcessors;
},
defaultSelectorOptions
);
const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
@@ -69,47 +60,26 @@ const ParamControlNetProcessorSelect = (
const controlNetProcessors = useAppSelector(selector);
const handleProcessorTypeChanged = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
(v: string | null) => {
dispatch(
controlNetProcessorTypeChanged({
controlNetId,
processorType: e.target.value as ControlNetProcessorType,
processorType: v as ControlNetProcessorType,
})
);
},
[controlNetId, dispatch]
);
// const handleProcessorTypeChanged = useCallback(
// (v: string | null | undefined) => {
// dispatch(
// controlNetProcessorTypeChanged({
// controlNetId,
// processorType: v as ControlNetProcessorType,
// })
// );
// },
// [controlNetId, dispatch]
// );
return (
<IAISelect
<IAIMantineSelect
label="Processor"
value={processorNode.type ?? 'canny_image_processor'}
validValues={controlNetProcessors}
data={controlNetProcessors}
onChange={handleProcessorTypeChanged}
isDisabled={!isReady}
disabled={!isReady}
/>
);
// return (
// <IAICustomSelect
// label="Processor"
// value={processorNode.type ?? 'canny_image_processor'}
// data={CONTROLNET_PROCESSOR_TYPES}
// onChange={handleProcessorTypeChanged}
// withCheckIcon
// isDisabled={!isReady}
// />
// );
};
export default memo(ParamControlNetProcessorSelect);

View File

@@ -1,6 +1,5 @@
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';
@@ -23,7 +22,7 @@ type ControlNetProcessorsDict = Record<
*
* TODO: Generate from the OpenAPI schema
*/
export const CONTROLNET_PROCESSORS = {
export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
none: {
type: 'none',
label: 'none',
@@ -174,6 +173,8 @@ export const CONTROLNET_PROCESSORS = {
},
};
type ControlNetModelsDict = Record<string, ControlNetModel>;
type ControlNetModel = {
type: string;
label: string;
@@ -181,7 +182,7 @@ type ControlNetModel = {
defaultProcessor?: ControlNetProcessorType;
};
export const CONTROLNET_MODELS = {
export const CONTROLNET_MODELS: ControlNetModelsDict = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
@@ -190,6 +191,7 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd',
@@ -209,6 +211,7 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segmentation',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',
@@ -223,6 +226,7 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge',
@@ -242,10 +246,12 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)',
defaultProcessor: 'none',
},
'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)',
defaultProcessor: 'none',
},
'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace',

View File

@@ -1,36 +1,27 @@
import { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { forEach } from 'lodash-es';
import { ImageDTO } from 'services/api';
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';
import { appSocketInvocationError } from 'services/events/actions';
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
import { isAnySessionRejected } from 'services/thunks/session';
import { controlNetImageProcessed } from './actions';
import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
import { forEach } from 'lodash-es';
import { isAnySessionRejected } from 'services/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};
export type ControlModes =
| 'balanced'
| 'more_prompt'
| 'more_control'
| 'unbalanced';
export type ControlNetConfig = {
controlNetId: string;
@@ -39,6 +30,7 @@ export type ControlNetConfig = {
weight: number;
beginStepPct: number;
endStepPct: number;
controlMode: ControlModes;
controlImage: ImageDTO | null;
processedControlImage: ImageDTO | null;
processorType: ControlNetProcessorType;
@@ -46,6 +38,21 @@ export type ControlNetConfig = {
shouldAutoConfig: boolean;
};
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
controlMode: 'balanced',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};
export type ControlNetState = {
controlNets: Record<string, ControlNetConfig>;
isEnabled: boolean;
@@ -147,11 +154,13 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
const processorType =
CONTROLNET_MODELS[model as keyof typeof CONTROLNET_MODELS]
.defaultProcessor;
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType
processorType as keyof typeof CONTROLNET_PROCESSORS
].default as RequiredControlNetProcessorNode;
} else {
state.controlNets[controlNetId].processorType = 'none';
@@ -181,6 +190,13 @@ export const controlNetSlice = createSlice({
const { controlNetId, endStepPct } = action.payload;
state.controlNets[controlNetId].endStepPct = endStepPct;
},
controlNetControlModeChanged: (
state,
action: PayloadAction<{ controlNetId: string; controlMode: ControlModes }>
) => {
const { controlNetId, controlMode } = action.payload;
state.controlNets[controlNetId].controlMode = controlMode;
},
controlNetProcessorParamsChanged: (
state,
action: PayloadAction<{
@@ -210,7 +226,7 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null;
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType
processorType as keyof typeof CONTROLNET_PROCESSORS
].default as RequiredControlNetProcessorNode;
state.controlNets[controlNetId].shouldAutoConfig = false;
},
@@ -227,12 +243,14 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) {
// manage the processor for the user
const processorType =
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor;
CONTROLNET_MODELS[
state.controlNets[controlNetId]
.model as keyof typeof CONTROLNET_MODELS
].defaultProcessor;
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType
processorType as keyof typeof CONTROLNET_PROCESSORS
].default as RequiredControlNetProcessorNode;
} else {
state.controlNets[controlNetId].processorType = 'none';
@@ -271,8 +289,7 @@ export const controlNetSlice = createSlice({
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;
forEach(state.controlNets, (c) => {
if (c.controlImage?.image_name === image_name) {
@@ -286,11 +303,11 @@ export const controlNetSlice = createSlice({
});
});
builder.addCase(appSocketInvocationError, (state, action) => {
builder.addCase(appSocketInvocationError, (state) => {
state.pendingControlImages = [];
});
builder.addMatcher(isAnySessionRejected, (state, action) => {
builder.addMatcher(isAnySessionRejected, (state) => {
state.pendingControlImages = [];
});
},
@@ -308,6 +325,7 @@ export const {
controlNetWeightChanged,
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
controlNetControlModeChanged,
controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
controlNetReset,

View File

@@ -9,15 +9,15 @@ import {
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { ImageDTO } from 'services/api';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
type MetadataItemProps = {
isLink?: boolean;
@@ -324,7 +324,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
borderRadius: 'base',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
w: 'max-content',
w: 'full',
}}
>
<pre>{metadataJSON}</pre>

View File

@@ -59,8 +59,7 @@ export const gallerySlice = createSlice({
}
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;
if (state.selectedImage?.image_name === image_name) {
state.selectedImage.image_url = image_url;

View File

@@ -86,8 +86,7 @@ const imagesSlice = createSlice({
imagesAdapter.removeOne(state, imageName);
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;
imagesAdapter.updateOne(state, {
id: image_name,

View File

@@ -103,8 +103,7 @@ const nodesSlice = createSlice({
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;
state.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => {

View File

@@ -45,6 +45,7 @@ export const addControlNetToLinearGraph = (
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
model,
processorType,
weight,
@@ -60,23 +61,22 @@ export const addControlNetToLinearGraph = (
type: 'controlnet',
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_mode: controlMode,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
const { image_name, image_origin } = processedControlImage;
const { image_name } = processedControlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else if (controlImage) {
// The control image is preprocessed
const { image_name, image_origin } = controlImage;
const { image_name } = controlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly

View File

@@ -354,7 +354,6 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
type: 'img_resize',
image: {
image_name: initialImage.image_name,
image_origin: initialImage.image_origin,
},
is_intermediate: true,
height,
@@ -392,7 +391,6 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
image_name: initialImage.image_name,
image_origin: initialImage.image_origin,
});
// Pass the image's dimensions to the `NOISE` node

View File

@@ -57,8 +57,7 @@ export const buildImg2ImgNode = (
}
imageToImageNode.image = {
image_name: initialImage.name,
image_origin: initialImage.type,
image_name: initialImage.image_name,
};
}

View File

@@ -1,12 +1,12 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setInfillMethod } from 'features/parameters/store/generationSlice';
import { systemSelector } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
@@ -30,17 +30,17 @@ const ParamInfillMethod = () => {
const { t } = useTranslation();
const handleChange = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
dispatch(setInfillMethod(e.target.value));
(v: string) => {
dispatch(setInfillMethod(v));
},
[dispatch]
);
return (
<IAISelect
<IAIMantineSelect
label={t('parameters.infillMethod')}
value={infillMethod}
validValues={infillMethods}
data={infillMethods}
onChange={handleChange}
/>
);

View File

@@ -1,15 +1,15 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { setBoundingBoxScaleMethod } from 'features/canvas/store/canvasSlice';
import {
BoundingBoxScale,
BOUNDING_BOX_SCALES_DICT,
BoundingBoxScale,
} from 'features/canvas/store/canvasTypes';
import { ChangeEvent, memo } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
@@ -30,16 +30,14 @@ const ParamScaleBeforeProcessing = () => {
const { t } = useTranslation();
const handleChangeBoundingBoxScaleMethod = (
e: ChangeEvent<HTMLSelectElement>
) => {
dispatch(setBoundingBoxScaleMethod(e.target.value as BoundingBoxScale));
const handleChangeBoundingBoxScaleMethod = (v: string) => {
dispatch(setBoundingBoxScaleMethod(v as BoundingBoxScale));
};
return (
<IAISelect
<IAIMantineSelect
label={t('parameters.scaleBeforeProcessing')}
validValues={BOUNDING_BOX_SCALES_DICT}
data={BOUNDING_BOX_SCALES_DICT}
value={boundingBoxScale}
onChange={handleChangeBoundingBoxScaleMethod}
/>

View File

@@ -2,23 +2,20 @@ import { createSelector } from '@reduxjs/toolkit';
import { Scheduler } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICustomSelect from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
[uiSelector, generationSelector],
(ui, generation) => {
// TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413
// but we need to wait for the next release before removing this special handling.
const allSchedulers = ui.schedulers
.filter((scheduler) => {
return !['dpmpp_2s'].includes(scheduler);
})
const allSchedulers: string[] = ui.schedulers
.slice()
.sort((a, b) => a.localeCompare(b));
return {
@@ -36,39 +33,23 @@ const ParamScheduler = () => {
const { t } = useTranslation();
const handleChange = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
dispatch(setScheduler(e.target.value as Scheduler));
(v: string | null) => {
if (!v) {
return;
}
dispatch(setScheduler(v as Scheduler));
},
[dispatch]
);
// const handleChange = useCallback(
// (v: string | null | undefined) => {
// if (!v) {
// return;
// }
// dispatch(setScheduler(v as Scheduler));
// },
// [dispatch]
// );
return (
<IAISelect
<IAIMantineSelect
label={t('parameters.scheduler')}
value={scheduler}
validValues={allSchedulers}
data={allSchedulers}
onChange={handleChange}
/>
);
// return (
// <IAICustomSelect
// label={t('parameters.scheduler')}
// value={scheduler}
// data={allSchedulers}
// onChange={handleChange}
// withCheckIcon
// />
// );
};
export default memo(ParamScheduler);

View File

@@ -1,12 +1,11 @@
import { FACETOOL_TYPES } from 'app/constants';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
FacetoolType,
setFacetoolType,
} from 'features/parameters/store/postprocessingSlice';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
export default function FaceRestoreType() {
@@ -17,13 +16,13 @@ export default function FaceRestoreType() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChangeFacetoolType = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setFacetoolType(e.target.value as FacetoolType));
const handleChangeFacetoolType = (v: string) =>
dispatch(setFacetoolType(v as FacetoolType));
return (
<IAISelect
<IAIMantineSelect
label={t('parameters.type')}
validValues={FACETOOL_TYPES.concat()}
data={FACETOOL_TYPES.concat()}
value={facetoolType}
onChange={handleChangeFacetoolType}
/>

View File

@@ -1,12 +1,11 @@
import { UPSCALING_LEVELS } from 'app/constants';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
setUpscalingLevel,
UpscalingLevel,
setUpscalingLevel,
} from 'features/parameters/store/postprocessingSlice';
import type { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
export default function UpscaleScale() {
@@ -21,16 +20,16 @@ export default function UpscaleScale() {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const handleChangeLevel = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setUpscalingLevel(Number(e.target.value) as UpscalingLevel));
const handleChangeLevel = (v: string) =>
dispatch(setUpscalingLevel(Number(v) as UpscalingLevel));
return (
<IAISelect
isDisabled={!isESRGANAvailable}
<IAIMantineSelect
disabled={!isESRGANAvailable}
label={t('parameters.scale')}
value={upscalingLevel}
value={String(upscalingLevel)}
onChange={handleChangeLevel}
validValues={UPSCALING_LEVELS}
data={UPSCALING_LEVELS}
/>
);
}

View File

@@ -1,11 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import { isObject } from 'lodash-es';
import { ImageDTO, ResourceOrigin } from 'services/api';
export type ImageNameAndOrigin = {
image_name: string;
image_origin: ResourceOrigin;
};
import { ImageDTO } from 'services/api';
export const initialImageSelected = createAction<ImageDTO | string | undefined>(
'generation/initialImageSelected'

View File

@@ -234,8 +234,7 @@ export const generationSlice = createSlice({
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;
if (state.initialImage?.image_name === image_name) {
state.initialImage.image_url = image_url;

View File

@@ -1,17 +1,16 @@
import { createSelector } from '@reduxjs/toolkit';
import { ChangeEvent, memo, useCallback } from 'react';
import { isEqual } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
import { RootState } from 'app/store/store';
import { modelSelected } from 'features/parameters/store/generationSlice';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
import { modelSelected } from 'features/parameters/store/generationSlice';
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
const selector = createSelector(
[(state: RootState) => state, generationSelector],
@@ -19,18 +18,11 @@ const selector = createSelector(
const selectedModel = selectModelsById(state, generation.model);
const modelData = selectModelsAll(state)
.map((m) => ({
.map<IAISelectDataType>((m) => ({
value: m.name,
key: m.name,
label: m.name,
}))
.sort((a, b) => a.key.localeCompare(b.key));
// const modelData = selectModelsAll(state)
// .map<IAICustomSelectOption>((m) => ({
// value: m.name,
// label: m.name,
// tooltip: m.description,
// }))
// .sort((a, b) => a.label.localeCompare(b.label));
.sort((a, b) => a.label.localeCompare(b.label));
return {
selectedModel,
modelData,
@@ -48,43 +40,25 @@ const ModelSelect = () => {
const { t } = useTranslation();
const { selectedModel, modelData } = useAppSelector(selector);
const handleChangeModel = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
dispatch(modelSelected(e.target.value));
(v: string | null) => {
if (!v) {
return;
}
dispatch(modelSelected(v));
},
[dispatch]
);
// const handleChangeModel = useCallback(
// (v: string | null | undefined) => {
// if (!v) {
// return;
// }
// dispatch(modelSelected(v));
// },
// [dispatch]
// );
return (
<IAISelect
label={t('modelManager.model')}
<IAIMantineSelect
tooltip={selectedModel?.description}
validValues={modelData}
label={t('modelManager.model')}
value={selectedModel?.name ?? ''}
placeholder="Pick one"
data={modelData}
onChange={handleChangeModel}
tooltipProps={{ placement: 'top', hasArrow: true }}
/>
);
// return (
// <IAICustomSelect
// label={t('modelManager.model')}
// tooltip={selectedModel?.description}
// data={modelData}
// value={selectedModel?.name ?? ''}
// onChange={handleChangeModel}
// withCheckIcon={true}
// tooltipProps={{ placement: 'top', hasArrow: true }}
// />
// );
};
export default memo(ModelSelect);

View File

@@ -13,19 +13,21 @@ import {
useDisclosure,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { VALID_LOG_LEVELS } from 'app/logging/useLogger';
import { LOCALSTORAGE_KEYS, LOCALSTORAGE_PREFIX } from 'app/store/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAISwitch from 'common/components/IAISwitch';
import { systemSelector } from 'features/system/store/systemSelectors';
import {
SystemState,
consoleLogLevelChanged,
setEnableImageDebugging,
setShouldConfirmOnDelete,
setShouldDisplayGuides,
shouldAntialiasProgressImageChanged,
shouldLogToConsoleChanged,
SystemState,
} from 'features/system/store/systemSlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import {
@@ -37,15 +39,13 @@ import { UIState } from 'features/ui/store/uiTypes';
import { isEqual } from 'lodash-es';
import {
ChangeEvent,
cloneElement,
ReactElement,
cloneElement,
useCallback,
useEffect,
} from 'react';
import { useTranslation } from 'react-i18next';
import { VALID_LOG_LEVELS } from 'app/logging/useLogger';
import { LogLevelName } from 'roarr';
import { LOCALSTORAGE_KEYS, LOCALSTORAGE_PREFIX } from 'app/store/constants';
import SettingsSchedulers from './SettingsSchedulers';
const selector = createSelector(
@@ -157,8 +157,8 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
}, [onSettingsModalClose, onRefreshModalOpen]);
const handleLogLevelChanged = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
dispatch(consoleLogLevelChanged(e.target.value as LogLevelName));
(v: string) => {
dispatch(consoleLogLevelChanged(v as LogLevelName));
},
[dispatch]
);
@@ -255,14 +255,12 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
isChecked={shouldLogToConsole}
onChange={handleLogToConsoleChanged}
/>
<IAISelect
horizontal
spaceEvenly
isDisabled={!shouldLogToConsole}
<IAIMantineSelect
disabled={!shouldLogToConsole}
label={t('settings.consoleLogLevel')}
onChange={handleLogLevelChanged}
value={consoleLogLevel}
validValues={VALID_LOG_LEVELS.concat()}
data={VALID_LOG_LEVELS.concat()}
/>
<IAISwitch
label={t('settings.enableImageDebugging')}

View File

@@ -1,5 +1,4 @@
import {
Box,
Menu,
MenuButton,
MenuItemOption,
@@ -13,7 +12,6 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { setSchedulers } from 'features/ui/store/uiSlice';
import { isArray } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { useTranslation } from 'react-i18next';
export default function SettingsSchedulers() {

View File

@@ -1,21 +1,21 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { memo, useMemo } from 'react';
import { Box, Flex } from '@chakra-ui/react';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import InvokeAILogoComponent from 'features/system/components/InvokeAILogoComponent';
import OverlayScrollable from './common/OverlayScrollable';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
import {
activeTabNameSelector,
uiSelector,
} from 'features/ui/store/uiSelectors';
import { setShouldShowParametersPanel } from 'features/ui/store/uiSlice';
import ResizableDrawer from './common/ResizableDrawer/ResizableDrawer';
import { memo, useMemo } from 'react';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
import PinParametersPanelButton from './PinParametersPanelButton';
import TextToImageTabParameters from './tabs/TextToImage/TextToImageTabParameters';
import OverlayScrollable from './common/OverlayScrollable';
import ResizableDrawer from './common/ResizableDrawer/ResizableDrawer';
import ImageToImageTabParameters from './tabs/ImageToImage/ImageToImageTabParameters';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import TextToImageTabParameters from './tabs/TextToImage/TextToImageTabParameters';
import UnifiedCanvasParameters from './tabs/UnifiedCanvas/UnifiedCanvasParameters';
const selector = createSelector(

View File

@@ -1,11 +1,11 @@
import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { PropsWithChildren, memo } from 'react';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
import OverlayScrollable from './common/OverlayScrollable';
import PinParametersPanelButton from './PinParametersPanelButton';
import { createSelector } from '@reduxjs/toolkit';
import { uiSelector } from '../store/uiSelectors';
import { useAppSelector } from 'app/store/storeHooks';
import PinParametersPanelButton from './PinParametersPanelButton';
import OverlayScrollable from './common/OverlayScrollable';
const selector = createSelector(uiSelector, (ui) => {
const { shouldPinParametersPanel, shouldShowParametersPanel } = ui;
@@ -35,19 +35,27 @@ const ParametersPinnedWrapper = (props: ParametersPinnedWrapperProps) => {
flexShrink: 0,
}}
>
<OverlayScrollable>
<Flex
sx={{
gap: 2,
flexDirection: 'column',
h: 'full',
w: 'full',
position: 'absolute',
}}
>
{props.children}
</Flex>
</OverlayScrollable>
<Flex
sx={{
gap: 2,
flexDirection: 'column',
h: 'full',
w: 'full',
position: 'absolute',
}}
>
<OverlayScrollable>
<Flex
sx={{
flexDirection: 'column',
gap: 2,
}}
>
{props.children}
</Flex>
</OverlayScrollable>
</Flex>
<PinParametersPanelButton
sx={{ position: 'absolute', top: 0, insetInlineEnd: 0 }}
/>

View File

@@ -1,6 +1,5 @@
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { PropsWithChildren, memo } from 'react';
const OverlayScrollable = (props: PropsWithChildren) => {
return (
<OverlayScrollbarsComponent
@@ -20,5 +19,4 @@ const OverlayScrollable = (props: PropsWithChildren) => {
</OverlayScrollbarsComponent>
);
};
export default memo(OverlayScrollable);

View File

@@ -1,14 +1,17 @@
import { Box, Flex } from '@chakra-ui/react';
import { memo, useCallback, useRef } from 'react';
import { Panel, PanelGroup } from 'react-resizable-panels';
import { useAppDispatch } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import ResizeHandle from '../ResizeHandle';
import ImageToImageTabParameters from './ImageToImageTabParameters';
import TextToImageTabMain from '../TextToImage/TextToImageTabMain';
import { ImperativePanelGroupHandle } from 'react-resizable-panels';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
import InitialImageDisplay from 'features/parameters/components/Parameters/ImageToImage/InitialImageDisplay';
import { memo, useCallback, useRef } from 'react';
import {
ImperativePanelGroupHandle,
Panel,
PanelGroup,
} from 'react-resizable-panels';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
import ResizeHandle from '../ResizeHandle';
import TextToImageTabMain from '../TextToImage/TextToImageTabMain';
import ImageToImageTabParameters from './ImageToImageTabParameters';
const ImageToImageTab = () => {
const dispatch = useAppDispatch();

View File

@@ -1,8 +1,8 @@
import { Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
import TextToImageTabMain from './TextToImageTabMain';
import TextToImageTabParameters from './TextToImageTabParameters';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
const TextToImageTab = () => {
return (

View File

@@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
canvasSelector,
isStagingSelector,
@@ -12,7 +12,6 @@ import {
} from 'features/canvas/store/canvasTypes';
import { isEqual } from 'lodash-es';
import { ChangeEvent } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@@ -51,22 +50,22 @@ export default function UnifiedCanvasLayerSelect() {
[layer]
);
const handleChangeLayer = (e: ChangeEvent<HTMLSelectElement>) => {
const newLayer = e.target.value as CanvasLayer;
const handleChangeLayer = (v: string) => {
const newLayer = v as CanvasLayer;
dispatch(setLayer(newLayer));
if (newLayer === 'mask' && !isMaskEnabled) {
dispatch(setIsMaskEnabled(true));
}
};
return (
<IAISelect
<IAIMantineSelect
tooltip={`${t('unifiedCanvas.layer')} (Q)`}
aria-label={`${t('unifiedCanvas.layer')} (Q)`}
tooltipProps={{ hasArrow: true, placement: 'top' }}
value={layer}
validValues={LAYER_NAMES_DICT}
data={LAYER_NAMES_DICT}
onChange={handleChangeLayer}
isDisabled={isStaging}
disabled={isStaging}
w="full"
/>
);
}

View File

@@ -2,6 +2,7 @@ import { Flex } from '@chakra-ui/react';
import IAICanvasRedoButton from 'features/canvas/components/IAICanvasToolbar/IAICanvasRedoButton';
import IAICanvasUndoButton from 'features/canvas/components/IAICanvasToolbar/IAICanvasUndoButton';
import UnifiedCanvasSettings from './UnifiedCanvasToolSettings/UnifiedCanvasSettings';
import UnifiedCanvasCopyToClipboard from './UnifiedCanvasToolbar/UnifiedCanvasCopyToClipboard';
import UnifiedCanvasDownloadImage from './UnifiedCanvasToolbar/UnifiedCanvasDownloadImage';
import UnifiedCanvasFileUploader from './UnifiedCanvasToolbar/UnifiedCanvasFileUploader';
@@ -13,11 +14,10 @@ import UnifiedCanvasResetCanvas from './UnifiedCanvasToolbar/UnifiedCanvasResetC
import UnifiedCanvasResetView from './UnifiedCanvasToolbar/UnifiedCanvasResetView';
import UnifiedCanvasSaveToGallery from './UnifiedCanvasToolbar/UnifiedCanvasSaveToGallery';
import UnifiedCanvasToolSelect from './UnifiedCanvasToolbar/UnifiedCanvasToolSelect';
import UnifiedCanvasSettings from './UnifiedCanvasToolSettings/UnifiedCanvasSettings';
const UnifiedCanvasToolbarBeta = () => {
return (
<Flex flexDirection="column" rowGap={2}>
<Flex flexDirection="column" rowGap={2} width="min-content">
<UnifiedCanvasLayerSelect />
<UnifiedCanvasToolSelect />

View File

@@ -1,8 +1,8 @@
import { Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
import UnifiedCanvasContent from './UnifiedCanvasContent';
import UnifiedCanvasParameters from './UnifiedCanvasParameters';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
const UnifiedCanvasTab = () => {
return (

View File

@@ -0,0 +1,23 @@
import { MantineThemeOverride } from '@mantine/core';
export const mantineTheme: MantineThemeOverride = {
colorScheme: 'dark',
fontFamily: `'InterVariable', sans-serif`,
components: {
ScrollArea: {
defaultProps: {
scrollbarSize: 10,
},
styles: {
scrollbar: {
'&:hover': {
backgroundColor: 'var(--invokeai-colors-baseAlpha-300)',
},
},
thumb: {
backgroundColor: 'var(--invokeai-colors-baseAlpha-300)',
},
},
},
},
};

View File

@@ -24,9 +24,11 @@ export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
export type { DivideInvocation } from './models/DivideInvocation';
export type { DynamicPromptInvocation } from './models/DynamicPromptInvocation';
export type { Edge } from './models/Edge';
export type { EdgeConnection } from './models/EdgeConnection';
export type { FloatCollectionOutput } from './models/FloatCollectionOutput';
export type { FloatLinearRangeInvocation } from './models/FloatLinearRangeInvocation';
export type { FloatOutput } from './models/FloatOutput';
export type { Graph } from './models/Graph';
export type { GraphExecutionState } from './models/GraphExecutionState';
@@ -85,6 +87,7 @@ export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedRe
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation';
export type { RandomRangeInvocation } from './models/RandomRangeInvocation';
@@ -95,6 +98,7 @@ export type { ResourceOrigin } from './models/ResourceOrigin';
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
export type { ShowImageInvocation } from './models/ShowImageInvocation';
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
export type { SubtractInvocation } from './models/SubtractInvocation';
export type { TextToImageInvocation } from './models/TextToImageInvocation';
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';

View File

@@ -24,4 +24,3 @@ export type AddInvocation = {
*/
'b'?: number;
};

View File

@@ -5,4 +5,3 @@
export type Body_upload_image = {
file: Blob;
};

View File

@@ -30,4 +30,3 @@ export type CannyImageProcessorInvocation = {
*/
high_threshold?: number;
};

View File

@@ -29,4 +29,3 @@ export type CkptModelInfo = {
*/
height?: number;
};

View File

@@ -24,4 +24,3 @@ export type CollectInvocation = {
*/
collection?: Array<any>;
};

View File

@@ -12,4 +12,3 @@ export type CollectInvocationOutput = {
*/
collection: Array<any>;
};

View File

@@ -20,4 +20,3 @@ export type ColorField = {
*/
'a': number;
};

View File

@@ -24,4 +24,3 @@ export type CompelInvocation = {
*/
model?: string;
};

View File

@@ -14,4 +14,3 @@ export type CompelOutput = {
*/
conditioning?: ConditioningField;
};

View File

@@ -8,4 +8,3 @@ export type ConditioningField = {
*/
conditioning_name: string;
};

View File

@@ -42,4 +42,3 @@ export type ContentShuffleImageProcessorInvocation = {
*/
'f'?: number;
};

View File

@@ -16,7 +16,7 @@ export type ControlField = {
/**
* The weight given to the ControlNet
*/
control_weight: number;
control_weight: (number | Array<number>);
/**
* When the ControlNet is first applied (% of total steps)
*/
@@ -25,5 +25,8 @@ export type ControlField = {
* When the ControlNet is last applied (% of total steps)
*/
end_step_percent: number;
/**
* The contorl mode to use
*/
control_mode?: 'balanced' | 'more_prompt' | 'more_control' | 'unbalanced';
};

View File

@@ -22,13 +22,13 @@ export type ControlNetInvocation = {
*/
image?: ImageField;
/**
* The ControlNet model to use
* control model used
*/
control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15' | 'CrucibleAI/ControlNetMediaPipeFace';
/**
* The weight given to the ControlNet
*/
control_weight?: number;
control_weight?: (number | Array<number>);
/**
* When the ControlNet is first applied (% of total steps)
*/
@@ -37,5 +37,8 @@ export type ControlNetInvocation = {
* When the ControlNet is last applied (% of total steps)
*/
end_step_percent?: number;
/**
* The control mode used
*/
control_mode?: 'balanced' | 'more_prompt' | 'more_control' | 'unbalanced';
};

View File

@@ -10,8 +10,7 @@ import type { ControlField } from './ControlField';
export type ControlOutput = {
type?: 'control_output';
/**
* The output control image
* The control info
*/
control?: ControlField;
};

View File

@@ -15,4 +15,3 @@ export type CreateModelRequest = {
*/
info: (CkptModelInfo | DiffusersModelInfo);
};

View File

@@ -26,4 +26,3 @@ export type CvInpaintInvocation = {
*/
mask?: ImageField;
};

View File

@@ -23,4 +23,3 @@ export type DiffusersModelInfo = {
*/
path?: string;
};

View File

@@ -24,4 +24,3 @@ export type DivideInvocation = {
*/
'b'?: number;
};

View File

@@ -0,0 +1,30 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator
*/
export type DynamicPromptInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'dynamic_prompt';
/**
* The prompt to parse with dynamicprompts
*/
prompt: string;
/**
* The number of prompts to generate
*/
max_prompts?: number;
/**
* Whether to use the combinatorial generator
*/
combinatorial?: boolean;
};

View File

@@ -14,4 +14,3 @@ export type Edge = {
*/
destination: EdgeConnection;
};

View File

@@ -12,4 +12,3 @@ export type EdgeConnection = {
*/
field: string;
};

View File

@@ -12,4 +12,3 @@ export type FloatCollectionOutput = {
*/
collection?: Array<number>;
};

View File

@@ -0,0 +1,30 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Creates a range
*/
export type FloatLinearRangeInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'float_range';
/**
* The first value of the range
*/
start?: number;
/**
* The last value of the range
*/
stop?: number;
/**
* number of values to interpolate over (including start and stop)
*/
steps?: number;
};

View File

@@ -12,4 +12,3 @@ export type FloatOutput = {
*/
param?: number;
};

View File

@@ -10,7 +10,9 @@ import type { ContentShuffleImageProcessorInvocation } from './ContentShuffleIma
import type { ControlNetInvocation } from './ControlNetInvocation';
import type { CvInpaintInvocation } from './CvInpaintInvocation';
import type { DivideInvocation } from './DivideInvocation';
import type { DynamicPromptInvocation } from './DynamicPromptInvocation';
import type { Edge } from './Edge';
import type { FloatLinearRangeInvocation } from './FloatLinearRangeInvocation';
import type { GraphInvocation } from './GraphInvocation';
import type { HedImageProcessorInvocation } from './HedImageProcessorInvocation';
import type { ImageBlurInvocation } from './ImageBlurInvocation';
@@ -55,6 +57,7 @@ import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation';
import type { RestoreFaceInvocation } from './RestoreFaceInvocation';
import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation';
import type { ShowImageInvocation } from './ShowImageInvocation';
import type { StepParamEasingInvocation } from './StepParamEasingInvocation';
import type { SubtractInvocation } from './SubtractInvocation';
import type { TextToImageInvocation } from './TextToImageInvocation';
import type { TextToLatentsInvocation } from './TextToLatentsInvocation';
@@ -69,10 +72,9 @@ export type Graph = {
/**
* The nodes in this graph
*/
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
nodes?: Record<string, (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | CompelInvocation | LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CvInpaintInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | DynamicPromptInvocation | RestoreFaceInvocation | UpscaleInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | ImageToImageInvocation | LatentsToLatentsInvocation | InpaintInvocation)>;
/**
* The connections between nodes and their fields in this graph
*/
edges?: Array<Edge>;
};

View File

@@ -16,6 +16,7 @@ import type { IterateInvocationOutput } from './IterateInvocationOutput';
import type { LatentsOutput } from './LatentsOutput';
import type { MaskOutput } from './MaskOutput';
import type { NoiseOutput } from './NoiseOutput';
import type { PromptCollectionOutput } from './PromptCollectionOutput';
import type { PromptOutput } from './PromptOutput';
/**
@@ -45,7 +46,7 @@ export type GraphExecutionState = {
/**
* The results of node executions
*/
results: Record<string, (ImageOutput | MaskOutput | ControlOutput | PromptOutput | CompelOutput | IntOutput | FloatOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | FloatCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
results: Record<string, (IntCollectionOutput | FloatCollectionOutput | CompelOutput | ImageOutput | MaskOutput | ControlOutput | LatentsOutput | NoiseOutput | IntOutput | FloatOutput | PromptOutput | PromptCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
/**
* Errors raised when executing nodes
*/
@@ -59,4 +60,3 @@ export type GraphExecutionState = {
*/
source_prepared_mapping: Record<string, Array<string>>;
};

View File

@@ -22,4 +22,3 @@ export type GraphInvocation = {
*/
graph?: Graph;
};

View File

@@ -8,4 +8,3 @@
export type GraphInvocationOutput = {
type: 'graph_output';
};

View File

@@ -7,4 +7,3 @@ import type { ValidationError } from './ValidationError';
export type HTTPValidationError = {
detail?: Array<ValidationError>;
};

Some files were not shown because too many files have changed in this diff Show More