Compare commits

..

3 Commits

Author SHA1 Message Date
Lincoln Stein
3c50448ccf Merge branch 'main' into dev/pytorch2 2023-04-06 21:47:46 -04:00
Lincoln Stein
5dec5b6f51 Merge branch 'main' into dev/pytorch2 2023-03-23 23:31:21 -04:00
Kevin Turner
e158ad8534 deps: upgrade to PyTorch 2.0 (replaces xformers) 2023-03-15 15:45:48 -07:00
268 changed files with 1023 additions and 11105 deletions

19
.github/stale.yaml vendored
View File

@@ -1,19 +0,0 @@
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 28
# Number of days of inactivity before a stale issue is closed
daysUntilClose: 14
# Issues with these labels will never be considered stale
exemptLabels:
- pinned
- security
# Label to use when marking an issue as stale
staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Please
update the ticket if this is still a problem on the latest release.
# Comment to post when closing a stale issue. Set to `false` to disable
closeComment: >
Due to inactivity, this issue has been automatically closed. If this is
still a problem on the latest release, please recreate the issue.

View File

@@ -1,18 +1,10 @@
# Invocations
Invocations represent a single operation, its inputs, and its outputs. These
operations and their outputs can be chained together to generate and modify
images.
Invocations represent a single operation, its inputs, and its outputs. These operations and their outputs can be chained together to generate and modify images.
## Creating a new invocation
To create a new invocation, either find the appropriate module file in
`/ldm/invoke/app/invocations` to add your invocation to, or create a new one in
that folder. All invocations in that folder will be discovered and made
available to the CLI and API automatically. Invocations make use of
[typing](https://docs.python.org/3/library/typing.html) and
[pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration
into the CLI and API.
To create a new invocation, either find the appropriate module file in `/ldm/invoke/app/invocations` to add your invocation to, or create a new one in that folder. All invocations in that folder will be discovered and made available to the CLI and API automatically. Invocations make use of [typing](https://docs.python.org/3/library/typing.html) and [pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration into the CLI and API.
An invocation looks like this:
@@ -49,54 +41,34 @@ class UpscaleInvocation(BaseInvocation):
Each portion is important to implement correctly.
### Class definition and type
```py
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
type: Literal['upscale'] = 'upscale'
```
All invocations must derive from `BaseInvocation`. They should have a docstring
that declares what they do in a single, short line. They should also have a
`type` with a type hint that's `Literal["command_name"]`, where `command_name`
is what the user will type on the CLI or use in the API to create this
invocation. The `command_name` must be unique. The `type` must be assigned to
the value of the literal in the type hint.
All invocations must derive from `BaseInvocation`. They should have a docstring that declares what they do in a single, short line. They should also have a `type` with a type hint that's `Literal["command_name"]`, where `command_name` is what the user will type on the CLI or use in the API to create this invocation. The `command_name` must be unique. The `type` must be assigned to the value of the literal in the type hint.
### Inputs
```py
# Inputs
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2,4] = Field(default=2, description="The upscale level")
```
Inputs consist of three parts: a name, a type hint, and a `Field` with default, description, and validation information. For example:
| Part | Value | Description |
| ---- | ----- | ----------- |
| Name | `strength` | This field is referred to as `strength` |
| Type Hint | `float` | This field must be of type `float` |
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
Inputs consist of three parts: a name, a type hint, and a `Field` with default,
description, and validation information. For example:
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this field to be parsed with `None` as a value, which enables linking to previous invocations. All fields should either provide a default value or allow `None` as a value, so that they can be overwritten with a linked output from another invocation.
| Part | Value | Description |
| --------- | ------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
| Name | `strength` | This field is referred to as `strength` |
| Type Hint | `float` | This field must be of type `float` |
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
The special type `ImageField` is also used here. All images are passed as `ImageField`, which protects them from pydantic validation errors (since images only ever come from links).
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this
field to be parsed with `None` as a value, which enables linking to previous
invocations. All fields should either provide a default value or allow `None` as
a value, so that they can be overwritten with a linked output from another
invocation.
The special type `ImageField` is also used here. All images are passed as
`ImageField`, which protects them from pydantic validation errors (since images
only ever come from links).
Finally, note that for all linking, the `type` of the linked fields must match.
If the `name` also matches, then the field can be **automatically linked** to a
previous invocation by name and matching.
Finally, note that for all linking, the `type` of the linked fields must match. If the `name` also matches, then the field can be **automatically linked** to a previous invocation by name and matching.
### Invoke Function
```py
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
@@ -116,22 +88,13 @@ previous invocation by name and matching.
image = ImageField(image_type = image_type, image_name = image_name)
)
```
The `invoke` function is the last portion of an invocation. It is provided an `InvocationContext` which contains services to perform work as well as a `session_id` for use as needed. It should return a class with output values that derives from `BaseInvocationOutput`.
The `invoke` function is the last portion of an invocation. It is provided an
`InvocationContext` which contains services to perform work as well as a
`session_id` for use as needed. It should return a class with output values that
derives from `BaseInvocationOutput`.
Before being called, the invocation will have all of its fields set from defaults, inputs, and finally links (overriding in that order).
Before being called, the invocation will have all of its fields set from
defaults, inputs, and finally links (overriding in that order).
Assume that this invocation may be running simultaneously with other
invocations, may be running on another machine, or in other interesting
scenarios. If you need functionality, please provide it as a service in the
`InvocationServices` class, and make sure it can be overridden.
Assume that this invocation may be running simultaneously with other invocations, may be running on another machine, or in other interesting scenarios. If you need functionality, please provide it as a service in the `InvocationServices` class, and make sure it can be overridden.
### Outputs
```py
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
@@ -139,64 +102,4 @@ class ImageOutput(BaseInvocationOutput):
image: ImageField = Field(default=None, description="The output image")
```
Output classes look like an invocation class without the invoke method. Prefer
to use an existing output class if available, and prefer to name inputs the same
as outputs when possible, to promote automatic invocation linking.
## Schema Generation
Invocation, output and related classes are used to generate an OpenAPI schema.
### Required Properties
The schema generation treat all properties with default values as optional. This
makes sense internally, but when when using these classes via the generated
schema, we end up with e.g. the `ImageOutput` class having its `image` property
marked as optional.
We know that this property will always be present, so the additional logic
needed to always check if the property exists adds a lot of extraneous cruft.
To fix this, we can leverage `pydantic`'s
[schema customisation](https://docs.pydantic.dev/usage/schema/#schema-customization)
to mark properties that we know will always be present as required.
Here's that `ImageOutput` class, without the needed schema customisation:
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
```
The generated OpenAPI schema, and all clients/types generated from it, will have
the `type` and `image` properties marked as optional, even though we know they
will always have a value by the time we can interact with them via the API.
Here's the same class, but with the schema customisation added:
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
class Config:
schema_extra = {
'required': [
'type',
'image',
]
}
```
The resultant schema (and any API client or types generated from it) will now
have see `type` as string literal `"image"` and `image` as an `ImageField`
object.
See this `pydantic` issue for discussion on this solution:
<https://github.com/pydantic/pydantic/discussions/4577>
Output classes look like an invocation class without the invoke method. Prefer to use an existing output class if available, and prefer to name inputs the same as outputs when possible, to promote automatic invocation linking.

View File

@@ -50,7 +50,7 @@ subset that are currently installed are found in
|stable-diffusion-1.5|runwayml/stable-diffusion-v1-5|Stable Diffusion version 1.5 diffusers model (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-v1-5 |
|sd-inpainting-1.5|runwayml/stable-diffusion-inpainting|RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-inpainting |
|stable-diffusion-2.1|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-inpainting|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-inpainting |
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|analog-diffusion-1.0|wavymulder/Analog-Diffusion|An SD-1.5 model trained on diverse analog photographs (2.13 GB)|https://huggingface.co/wavymulder/Analog-Diffusion |
|deliberate-1.0|XpucT/Deliberate|Versatile model that produces detailed images up to 768px (4.27 GB)|https://huggingface.co/XpucT/Deliberate |
|d&d-diffusion-1.0|0xJustin/Dungeons-and-Diffusion|Dungeons & Dragons characters (2.13 GB)|https://huggingface.co/0xJustin/Dungeons-and-Diffusion |

View File

@@ -461,8 +461,7 @@ def get_torch_source() -> (Union[str, None],str):
url = "https://download.pytorch.org/whl/cpu"
if device == 'cuda':
url = 'https://download.pytorch.org/whl/cu117'
optional_modules = '[xformers]'
url = 'https://download.pytorch.org/whl/cu118'
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@@ -1,14 +0,0 @@
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
from invokeai.app.models.metadata import ImageMetadata
class ImageResponse(BaseModel):
"""The response type for images"""
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
image_url: str = Field(description="The url of the image")
thumbnail_url: str = Field(description="The url of the image's thumbnail")
metadata: ImageMetadata = Field(description="The image's metadata")

View File

@@ -1,23 +1,18 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
import uuid
from fastapi import Path, Query, Request, UploadFile
from datetime import datetime, timezone
from fastapi import Path, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of image to get"),
@@ -43,60 +38,29 @@ async def get_thumbnail(
"/uploads/",
operation_id="upload_image",
responses={
201: {"description": "The image was uploaded successfully", "model": ImageResponse},
201: {"description": "The image was uploaded successfully"},
404: {"description": "Session not found"},
},
status_code=201
)
async def upload_image(file: UploadFile, request: Request, response: Response) -> ImageResponse:
async def upload_image(file: UploadFile, request: Request):
if not file.content_type.startswith("image"):
return Response(status_code=415)
contents = await file.read()
try:
img = Image.open(io.BytesIO(contents))
im = Image.open(contents)
except:
# Error opening the image
return Response(status_code=415)
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
image_path = ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, img)
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
res = ImageResponse(
image_type=ImageType.UPLOAD,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=f"api/v1/images/{ImageType.UPLOAD.value}/{filename}",
thumbnail_url=f"api/v1/images/{ImageType.UPLOAD.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
metadata=ImageMetadata(
created=int(os.path.getctime(image_path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata
),
)
response.status_code = 201
response.headers["Location"] = request.url_for(
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
return Response(
status_code=201,
headers={
"Location": request.url_for(
"get_image", image_type=ImageType.UPLOAD, image_name=filename
)
return res
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(
image_type, page, per_page
},
)
return result

View File

@@ -1,17 +1,11 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import shutil
import asyncio
from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter, HTTPException
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as
from pathlib import Path
from ..dependencies import ApiDependencies
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
from invokeai.backend.args import Args
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@@ -21,9 +15,11 @@ class VaeRepo(BaseModel):
path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
@@ -33,6 +29,7 @@ class CkptModelInfo(ModelInfo):
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['diffusers'] = 'diffusers'
@@ -40,29 +37,12 @@ class DiffusersModelInfo(ModelInfo):
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CreateModelRequest(BaseModel):
name: str = Field(description="The name of the model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
class CreateModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response")
class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
@models_router.get(
"/",
operation_id="list_models",
@@ -74,61 +54,108 @@ async def list_models() -> ModelsList:
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
# @socketio.on("requestSystemConfig")
# def handle_request_capabilities():
# print(">> System config requested")
# config = self.get_system_config()
# config["model_list"] = self.generate.model_manager.list_models()
# config["infill_methods"] = infill_methods()
# socketio.emit("systemConfig", config)
@models_router.post(
"/",
operation_id="update_model",
responses={200: {"status": "success"}},
)
async def update_model(
model_request: CreateModelRequest
) -> CreateModelResponse:
""" Add Model """
model_request_info = model_request.info
info_dict = model_request_info.dict()
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
# @socketio.on("searchForModels")
# def handle_search_models(search_folder: str):
# try:
# if not search_folder:
# socketio.emit(
# "foundModels",
# {"search_folder": None, "found_models": None},
# )
# else:
# (
# search_folder,
# found_models,
# ) = self.generate.model_manager.search_models(search_folder)
# socketio.emit(
# "foundModels",
# {"search_folder": search_folder, "found_models": found_models},
# )
# except Exception as e:
# self.handle_exceptions(e)
# print("\n")
ApiDependencies.invoker.services.model_manager.add_model(
model_name=model_request.name,
model_attributes=info_dict,
clobber=True,
)
# @socketio.on("addNewModel")
# def handle_add_model(new_model_config: dict):
# try:
# model_name = new_model_config["name"]
# del new_model_config["name"]
# model_attributes = new_model_config
# if len(model_attributes["vae"]) == 0:
# del model_attributes["vae"]
# update = False
# current_model_list = self.generate.model_manager.list_models()
# if model_name in current_model_list:
# update = True
return model_response
# print(f">> Adding New Model: {model_name}")
# self.generate.model_manager.add_model(
# model_name=model_name,
# model_attributes=model_attributes,
# clobber=True,
# )
# self.generate.model_manager.commit(opt.conf)
@models_router.delete(
"/{model_name}",
operation_id="del_model",
responses={
204: {
"description": "Model deleted successfully"
},
404: {
"description": "Model not found"
}
},
)
async def delete_model(model_name: str) -> None:
"""Delete Model"""
model_names = ApiDependencies.invoker.services.model_manager.model_names()
model_exists = model_name in model_names
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "newModelAdded",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": update,
# },
# )
# print(f">> New Model Added: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# check if model exists
print(f">> Checking for model {model_name}...")
if model_exists:
print(f">> Deleting Model: {model_name}")
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
print(f">> Model Deleted: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
else:
print(f">> Model not found")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
# @socketio.on("deleteModel")
# def handle_delete_model(model_name: str):
# try:
# print(f">> Deleting Model: {model_name}")
# self.generate.model_manager.del_model(model_name)
# self.generate.model_manager.commit(opt.conf)
# updated_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelDeleted",
# {
# "deleted_model_name": model_name,
# "model_list": updated_model_list,
# },
# )
# print(f">> Model Deleted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("convertToDiffusers")
# @socketio.on("requestModelChange")
# def handle_set_model(model_name: str):
# try:
# print(f">> Model change requested: {model_name}")
# model = self.generate.set_model(model_name)
# model_list = self.generate.model_manager.list_models()
# if model is None:
# socketio.emit(
# "modelChangeFailed",
# {"model_name": model_name, "model_list": model_list},
# )
# else:
# socketio.emit(
# "modelChanged",
# {"model_name": model_name, "model_list": model_list},
# )
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
@@ -248,4 +275,5 @@ async def delete_model(model_name: str) -> None:
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:
# except Exception as e:
# self.handle_exceptions(e)

View File

@@ -6,8 +6,7 @@ from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_t
from pydantic import BaseModel, Field
import networkx as nx
import matplotlib.pyplot as plt
from ..models.image import ImageField
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState
from ..services.invoker import Invoker

View File

@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
from typing import get_args, get_type_hints
from pydantic import BaseModel, Field
@@ -76,56 +76,3 @@ class BaseInvocation(ABC, BaseModel):
#fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.")
#fmt: on
# TODO: figure out a better way to provide these hints
# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False`
class UIConfig(TypedDict, total=False):
type_hints: Dict[
str,
Literal[
"integer",
"float",
"boolean",
"string",
"enum",
"image",
"latents",
"model",
],
]
tags: List[str]
class CustomisedSchemaExtra(TypedDict):
ui: UIConfig
class InvocationConfig(BaseModel.Config):
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
`tags`
- A list of strings, used to categorise invocations.
`type_hints`
- A dict of field types which override the types in the invocation definition.
- Each key should be the name of one of the invocation's fields.
- Each value should be one of the valid types:
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
```python
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["stable-diffusion", "image"],
"type_hints": {
"initial_image": "image",
},
},
}
```
"""
schema_extra: CustomisedSchemaExtra

View File

@@ -5,26 +5,14 @@ from typing import Literal
import cv2 as cv
import numpy
from PIL import Image, ImageOps
from pydantic import BaseModel, Field
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
class CvInvocationConfig(BaseModel):
"""Helper class to provide all OpenCV invocations with additional config"""
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["cv", "image"],
},
}
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
class CvInpaintInvocation(BaseInvocation):
"""Simple inpaint using opencv."""
#fmt: off
type: Literal["cv_inpaint"] = "cv_inpaint"
@@ -56,9 +44,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_inpainted, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_inpainted,
context.services.images.save(image_type, image_name, image_inpainted)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -6,37 +6,21 @@ from typing import Literal, Optional, Union
import numpy as np
from torch import Tensor
from pydantic import BaseModel, Field
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..models.exceptions import CanceledException
from ..util.step_callback import diffusers_step_callback_adapter
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
class SDImageInvocation(BaseModel):
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["stable-diffusion", "image"],
"type_hints": {
"model": "model",
},
},
}
from ..util.util import diffusers_step_callback_adapter, CanceledException
SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers())
]
# Text to image
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
class TextToImageInvocation(BaseInvocation):
"""Generates an image using text2img."""
type: Literal["txt2img"] = "txt2img"
@@ -50,7 +34,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
@@ -74,10 +58,16 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
self.model_name = model["model_name"]
# def step_callback(state: PipelineIntermediateState):
# if (context.services.queue.is_canceled(context.graph_execution_state_id)):
# raise CanceledException
# self.dispatch_progress(context, state.latents, state.step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
# (right now uses whatever current model is set in model manager)
model= context.services.model_manager.get_model()
outputs = Txt2Img(model).generate(
prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context),
@@ -96,22 +86,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_id = graph_execution_state.prepared_source_mapping[self.id]
invocation = graph_execution_state.execution_graph.get_node(self.id)
metadata = {
"session": context.graph_execution_state_id,
"source_id": source_id,
"invocation": invocation.dict()
}
context.services.images.save(image_type, image_name, generate_output.image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=generate_output.image
context.services.images.save(image_type, image_name, generate_output.image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
@@ -157,9 +134,9 @@ class ImageToImageInvocation(TextToImageInvocation):
mask = None
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
self.model = model["model_name"]
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
model = context.services.model_manager.get_model()
outputs = Img2Img(model).generate(
prompt=self.prompt,
init_image=image,
@@ -183,11 +160,9 @@ class ImageToImageInvocation(TextToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class InpaintInvocation(ImageToImageInvocation):
@@ -235,9 +210,9 @@ class InpaintInvocation(ImageToImageInvocation):
)
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
self.model = model["model_name"]
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
model = context.services.model_manager.get_model()
outputs = Inpaint(model).generate(
prompt=self.prompt,
init_img=image,
@@ -261,9 +236,7 @@ class InpaintInvocation(ImageToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -7,90 +7,65 @@ import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
image_type: str = Field(
default=ImageType.RESULT, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
#fmt: off
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
#fmt: on
class Config:
schema_extra = {
"required": [
"type",
"image",
"width",
"height",
'required': [
'type',
'image',
]
}
def build_image_output(
image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
image_field = ImageField(image_name=image_name, image_type=image_type)
return ImageOutput(image=image_field, width=image.width, height=image.height)
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
#fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
# fmt: on
#fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
'required': [
'type',
'mask',
]
}
# TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation):
"""Load an image from a filename and provide it as output."""
#fmt: off
type: Literal["load_image"] = "load_image"
# # TODO: this isn't really necessary anymore
# class LoadImageInvocation(BaseInvocation):
# """Load an image from a filename and provide it as output."""
# #fmt: off
# type: Literal["load_image"] = "load_image"
# Inputs
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
#fmt: on
# # Inputs
# image_type: ImageType = Field(description="The type of the image")
# image_name: str = Field(description="The name of the image")
# #fmt: on
# def invoke(self, context: InvocationContext) -> ImageOutput:
# return ImageOutput(
# image_type=self.image_type,
# image_name=self.image_name,
# image=result_image
# )
def invoke(self, context: InvocationContext) -> ImageOutput:
return ImageOutput(
image=ImageField(image_type=self.image_type, image_name=self.image_name)
)
class ShowImageInvocation(BaseInvocation):
@@ -110,17 +85,16 @@ class ShowImageInvocation(BaseInvocation):
# TODO: how to handle failure?
return build_image_output(
image_type=self.image.image_type,
image_name=self.image.image_name,
image=image,
return ImageOutput(
image=ImageField(
image_type=self.image.image_type, image_name=self.image.image_name
)
)
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
class CropImageInvocation(BaseInvocation):
"""Crops an image to a specified box. The box can be outside of the image."""
# fmt: off
#fmt: off
type: Literal["crop"] = "crop"
# Inputs
@@ -129,7 +103,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@@ -145,16 +119,15 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_crop, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=image_crop
context.services.images.save(image_type, image_name, image_crop)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
class PasteImageInvocation(BaseInvocation):
"""Pastes an image into another image."""
# fmt: off
#fmt: off
type: Literal["paste"] = "paste"
# Inputs
@@ -163,7 +136,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(
@@ -176,7 +149,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
None
if self.mask is None
else ImageOps.invert(
context.services.images.get(self.mask.image_type, self.mask.image_name)
services.images.get(self.mask.image_type, self.mask.image_name)
)
)
# TODO: probably shouldn't invert mask here... should user be required to do it?
@@ -196,22 +169,21 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, new_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=new_image
context.services.images.save(image_type, image_name, new_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
class MaskFromAlphaInvocation(BaseInvocation):
"""Extracts the alpha channel of an image as a mask."""
# fmt: off
#fmt: off
type: Literal["tomask"] = "tomask"
# Inputs
image: ImageField = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(
@@ -226,22 +198,22 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_mask, self.dict())
context.services.images.save(image_type, image_name, image_mask)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
class BlurInvocation(BaseInvocation, PILInvocationConfig):
class BlurInvocation(BaseInvocation):
"""Blurs an image"""
# fmt: off
#fmt: off
type: Literal["blur"] = "blur"
# Inputs
image: ImageField = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@@ -258,23 +230,22 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, blur_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
context.services.images.save(image_type, image_name, blur_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class LerpInvocation(BaseInvocation, PILInvocationConfig):
class LerpInvocation(BaseInvocation):
"""Linear interpolation of all pixels of an image"""
# fmt: off
#fmt: off
type: Literal["lerp"] = "lerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@@ -290,24 +261,23 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, lerp_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
context.services.images.save(image_type, image_name, lerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
class InverseLerpInvocation(BaseInvocation):
"""Inverse linear interpolation of all pixels of an image"""
# fmt: off
#fmt: off
type: Literal["ilerp"] = "ilerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@@ -327,7 +297,7 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, ilerp_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image
context.services.images.save(image_type, image_name, ilerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -2,24 +2,24 @@
from typing import Literal, Optional
from pydantic import BaseModel, Field
from torch import Tensor
import torch
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.util.step_callback import diffusers_step_callback_adapter
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.util.devices import CUDA_DEVICE, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
import numpy as np
from accelerate.utils import set_seed
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output
from .image import ImageField, ImageOutput
from ...backend.generator import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline
@@ -109,17 +109,8 @@ class NoiseInvocation(BaseInvocation):
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "noise"],
},
}
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(choose_torch_device())
device = torch.device(CUDA_DEVICE)
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
@@ -145,50 +136,46 @@ class TextToLatentsInvocation(BaseInvocation):
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
"type_hints": {
"model": "model"
}
},
}
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
self, context: InvocationContext, sample: Tensor, step: int
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
(width, height) = image.size
width *= 8
height *= 8
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
self.steps,
)
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
model_info = choose_model(model_manager, self.model)
model_info = model_manager.get_model(self.model)
model_name = model_info['model_name']
model_hash = model_info['hash']
model: StableDiffusionGeneratorPipeline = model_info['model']
model.scheduler = get_scheduler(
model=model,
scheduler_name=self.scheduler
scheduler_name=self.sampler_name
)
if isinstance(model, DiffusionPipeline):
@@ -227,7 +214,7 @@ class TextToLatentsInvocation(BaseInvocation):
noise = context.services.latents.get(self.noise.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state)
self.dispatch_progress(context, state.latents, state.step)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
@@ -257,17 +244,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
type: Literal["l2l"] = "l2l"
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model"
}
},
}
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
@@ -277,7 +253,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
latent = context.services.latents.get(self.latents.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state)
self.dispatch_progress(context, state.latents, state.step)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
@@ -323,23 +299,12 @@ class LatentsToImageInvocation(BaseInvocation):
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
model: str = Field(default="", description="The model to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model)
model_info = context.services.model_manager.get_model(self.model)
model: StableDiffusionGeneratorPipeline = model_info['model']
with torch.inference_mode():
@@ -350,9 +315,7 @@ class LatentsToImageInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image
context.services.images.save(image_type, image_name, image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -1,22 +1,15 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
from datetime import datetime, timezone
from typing import Literal, Optional
import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
class MathInvocationConfig(BaseModel):
"""Helper class to provide all math invocations with additional config"""
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["math"],
}
}
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
class IntOutput(BaseInvocationOutput):
@@ -27,7 +20,7 @@ class IntOutput(BaseInvocationOutput):
#fmt: on
class AddInvocation(BaseInvocation, MathInvocationConfig):
class AddInvocation(BaseInvocation):
"""Adds two numbers"""
#fmt: off
type: Literal["add"] = "add"
@@ -39,7 +32,7 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
return IntOutput(a=self.a + self.b)
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
class SubtractInvocation(BaseInvocation):
"""Subtracts two numbers"""
#fmt: off
type: Literal["sub"] = "sub"
@@ -51,7 +44,7 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
return IntOutput(a=self.a - self.b)
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
class MultiplyInvocation(BaseInvocation):
"""Multiplies two numbers"""
#fmt: off
type: Literal["mul"] = "mul"
@@ -63,7 +56,7 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
return IntOutput(a=self.a * self.b)
class DivideInvocation(BaseInvocation, MathInvocationConfig):
class DivideInvocation(BaseInvocation):
"""Divides two numbers"""
#fmt: off
type: Literal["div"] = "div"

View File

@@ -3,10 +3,10 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
@@ -18,14 +18,6 @@ class RestoreFaceInvocation(BaseInvocation):
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
#fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["restoration", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@@ -44,9 +36,7 @@ class RestoreFaceInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0], self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -5,10 +5,10 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
class UpscaleInvocation(BaseInvocation):
@@ -22,15 +22,6 @@ class UpscaleInvocation(BaseInvocation):
level: Literal[2, 4] = Field(default=2, description="The upscale level")
#fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@@ -49,9 +40,7 @@ class UpscaleInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0], self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -1,14 +0,0 @@
from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
if model_manager.valid_model(model_name):
model = model_manager.get_model(model_name)
else:
model = model_manager.get_model()
print(
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
)
return model

View File

@@ -1,3 +0,0 @@
class CanceledException(Exception):
"""Execution canceled by user."""
pass

View File

@@ -1,26 +0,0 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class ImageType(str, Enum):
RESULT = "results"
INTERMEDIATE = "intermediates"
UPLOAD = "uploads"
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_type: ImageType = Field(
default=ImageType.RESULT, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {
"required": [
"image_type",
"image_name",
]
}

View File

@@ -1,26 +0,0 @@
from typing import Any, Optional, Dict
from pydantic import BaseModel, Field
class InvokeAIMetadata(BaseModel):
"""An image's InvokeAI-specific metadata"""
session: Optional[str] = Field(description="The session that generated this image")
source_id: Optional[str] = Field(
description="The source id of the invocation that generated this image"
)
# TODO: figure out metadata
invocation: Optional[Dict[str, Any]] = Field(
default={}, description="The prepared invocation that generated this image"
)
class ImageMetadata(BaseModel):
"""An image's general metadata"""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
invokeai: Optional[InvokeAIMetadata] = Field(
default={}, description="The image's InvokeAI-specific metadata"
)

View File

@@ -25,8 +25,7 @@ class EventServiceBase:
def emit_generator_progress(
self,
graph_execution_state_id: str,
invocation_dict: dict,
source_id: str,
invocation_id: str,
progress_image: ProgressImage | None,
step: int,
total_steps: int,
@@ -36,8 +35,7 @@ class EventServiceBase:
event_name="generator_progress",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation=invocation_dict,
source_id=source_id,
invocation_id=invocation_id,
progress_image=progress_image,
step=step,
total_steps=total_steps,
@@ -45,43 +43,40 @@ class EventServiceBase:
)
def emit_invocation_complete(
self, graph_execution_state_id: str, result: Dict, invocation_dict: Dict, source_id: str,
self, graph_execution_state_id: str, invocation_id: str, result: Dict
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_complete",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation=invocation_dict,
source_id=source_id,
invocation_id=invocation_id,
result=result,
),
)
def emit_invocation_error(
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str, error: str
self, graph_execution_state_id: str, invocation_id: str, error: str
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation=invocation_dict,
source_id=source_id,
invocation_id=invocation_id,
error=error,
),
)
def emit_invocation_started(
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str
self, graph_execution_state_id: str, invocation_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
event_name="invocation_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation=invocation_dict,
source_id=source_id,
invocation_id=invocation_id,
),
)

View File

@@ -794,6 +794,9 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
# Declare all fields as required; necessary for OpenAPI schema generation build.
# Technically only fields without a `default_factory` need to be listed here.
# See: https://github.com/pydantic/pydantic/discussions/4577
class Config:
schema_extra = {
'required': [

View File

@@ -2,26 +2,24 @@
import datetime
import os
import json
from glob import glob
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from queue import Queue
from typing import Any, Callable, Dict, List, Union
from typing import Dict
from PIL.Image import Image
import PIL.Image as PILImage
from pydantic import BaseModel, Json
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.save_thumbnail import save_thumbnail
from invokeai.backend.image_util import PngWriter
class ImageType(str, Enum):
RESULT = "results"
INTERMEDIATE = "intermediates"
UPLOAD = "uploads"
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
@@ -29,21 +27,13 @@ class ImageStorageBase(ABC):
def get(self, image_type: ImageType, image_name: str) -> Image:
pass
@abstractmethod
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
def get_path(self, image_type: ImageType, image_name: str) -> str:
pass
@abstractmethod
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
pass
@abstractmethod
@@ -81,84 +71,25 @@ class DiskImageStorage(ImageStorageBase):
parents=True, exist_ok=True
)
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
dir_path = os.path.join(self.__output_folder, image_type)
image_paths = glob(f"{dir_path}/*.png")
count = len(image_paths)
sorted_image_paths = sorted(
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
)
page_of_image_paths = sorted_image_paths[
page * per_page : (page + 1) * per_page
]
page_of_images: List[ImageResponse] = []
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
page_of_images.append(
ImageResponse(
image_type=image_type.value,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=f"api/v1/images/{image_type.value}/{filename}",
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
metadata=ImageMetadata(
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata
),
)
)
page_count_trunc = int(count / per_page)
page_count_mod = count % per_page
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
return PaginatedResults[ImageResponse](
items=page_of_images,
page=page,
pages=page_count,
per_page=per_page,
total=count,
)
def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = PILImage.open(image_path)
image = Image.open(image_path)
self.__set_cache(image_path, image)
return image
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
if is_thumbnail:
path = os.path.join(
self.__output_folder, image_type, "thumbnails", image_name
)
else:
path = os.path.join(self.__output_folder, image_type, image_name)
def get_path(self, image_type: ImageType, image_name: str) -> str:
path = os.path.join(self.__output_folder, image_type, image_name)
return path
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
print(metadata)
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
image_subpath = os.path.join(image_type, image_name)
self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, metadata
image, "", image_subpath, None
) # TODO: just pass full path to png writer
save_thumbnail(
image=image,
@@ -167,23 +98,15 @@ class DiskImageStorage(ImageStorageBase):
)
image_path = self.get_path(image_type, image_name)
self.__set_cache(image_path, image)
return image_path
def delete(self, image_type: ImageType, image_name: str) -> None:
image_path = self.get_path(image_type, image_name)
thumbnail_path = self.get_path(image_type, image_name, True)
if os.path.exists(image_path):
os.remove(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
if os.path.exists(thumbnail_path):
os.remove(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
def __get_cache(self, image_name: str) -> Image:
return None if image_name not in self.__cache else self.__cache[image_name]

View File

@@ -4,7 +4,7 @@ from threading import Event, Thread
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from ..models.exceptions import CanceledException
from ..util.util import CanceledException
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
@@ -43,14 +43,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
queue_item.invocation_id
)
# get the source node to provide to cliepnts (the prepared node is not as useful)
source_id = graph_execution_state.prepared_source_mapping[invocation.id]
# Send starting event
self.__invoker.services.events.emit_invocation_started(
graph_execution_state_id=graph_execution_state.id,
invocation_dict=invocation.dict(),
source_id=source_id
invocation_id=invocation.id,
)
# Invoke
@@ -79,8 +75,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
graph_execution_state_id=graph_execution_state.id,
invocation_dict=invocation.dict(),
source_id=source_id,
invocation_id=invocation.id,
result=outputs.dict(),
)
@@ -104,8 +99,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
invocation_dict=invocation.dict(),
source_id=source_id,
invocation_id=invocation.id,
error=error,
)

View File

@@ -1,25 +1,23 @@
import sqlite3
from threading import Lock
from typing import Generic, TypeVar, Union, get_args
from pydantic import BaseModel, parse_raw_as
from .item_storage import ItemStorageABC, PaginatedResults
from sqlalchemy import create_engine, String, TEXT, Engine, select
from sqlalchemy.orm import DeclarativeBase, mapped_column, Session
T = TypeVar("T", bound=BaseModel)
class Base(DeclarativeBase):
pass
sqlite_memory = ":memory:"
class SqliteItemStorage(ItemStorageABC, Generic[T]):
_filename: str
_table_name: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_id_field: str
_engine: Engine
# _table: ??? # TODO: figure out how to type this
_lock: Lock
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
super().__init__()
@@ -27,79 +25,86 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._filename = filename
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor()
self._engine = create_engine(f"sqlite+pysqlite:///{self._filename}")
self._create_table()
def _create_table(self):
# dynamically create the ORM model class to avoid name collisions
# cannot access `self.__orig_class__` in `__init__` or `__new__` so
# format the table name into the class name
pascal_table_name = self._table_name.replace("_", " ").title()
pascal_table_name = pascal_table_name.replace(" ", "")
table_dict = dict(
__tablename__=self._table_name,
id=mapped_column(String, primary_key=True),
item=mapped_column(TEXT, nullable=False),
)
self._table = type(pascal_table_name, (Base,), table_dict)
Base.metadata.create_all(self._engine)
try:
self._lock.acquire()
self._cursor.execute(
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
item TEXT,
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
)
self._cursor.execute(
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
)
finally:
self._lock.release()
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item)
def set(self, item: T):
session = Session(self._engine)
item_id = str(getattr(item, self._id_field))
new_item = self._table(id=item_id, item=item.json())
session.merge(new_item)
session.commit()
session.close()
try:
self._lock.acquire()
self._cursor.execute(
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),),
)
self._conn.commit()
finally:
self._lock.release()
self._on_changed(item)
def get(self, id: str) -> Union[T, None]:
session = Session(self._engine)
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone()
finally:
self._lock.release()
item = session.get(self._table, str(id))
session.close()
if not item:
if not result:
return None
return self._parse_item(item.item)
return self._parse_item(result[0])
def delete(self, id: str):
session = Session(self._engine)
item = session.get(self._table, id)
session.delete(item)
session.commit()
session.close()
try:
self._lock.acquire()
self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._conn.commit()
finally:
self._lock.release()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
session = Session(self._engine)
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
(per_page, page * per_page),
)
result = self._cursor.fetchall()
stmt = select(self._table.item).limit(per_page).offset(page * per_page)
result = session.execute(stmt)
items = list(map(lambda r: self._parse_item(r[0]), result))
items = list(map(lambda r: self._parse_item(r[0]), result))
count = session.query(self._table.item).count()
session.commit()
session.close()
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
@@ -110,19 +115,23 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]:
session = Session(self._engine)
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
(f"%{query}%", per_page, page * per_page),
)
result = self._cursor.fetchall()
stmt = (
session.query(self._table)
.where(self._table.item.like(f"%{query}%"))
.limit(per_page)
.offset(page * per_page)
)
items = list(map(lambda r: self._parse_item(r[0]), result))
result = session.execute(stmt)
items = list(map(lambda r: self._parse_item(r[0].item), result))
count = session.query(self._table.item).count()
self._cursor.execute(
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
(f"%{query}%",),
)
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1

View File

@@ -1,17 +1,14 @@
from re import S
import torch
from PIL import Image
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
def fast_latents_step_callback(
sample: torch.Tensor,
step: int,
steps: int,
id: str,
context: InvocationContext,
):
class CanceledException(Exception):
pass
def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
@@ -21,21 +18,18 @@ def fast_latents_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_id = graph_execution_state.prepared_source_mapping[id]
invocation = graph_execution_state.execution_graph.get_node(id)
context.services.events.emit_generator_progress(
graph_execution_state_id=context.graph_execution_state_id,
invocation_dict=invocation.dict(),
source_id=source_id,
progress_image={"width": width, "height": height, "dataURL": dataURL},
step=step,
total_steps=steps,
context.graph_execution_state_id,
id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
steps,
)
def diffusers_step_callback_adapter(*cb_args, **kwargs):
"""
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
@@ -43,8 +37,6 @@ def diffusers_step_callback_adapter(*cb_args, **kwargs):
"""
if isinstance(cb_args[0], PipelineIntermediateState):
progress_state: PipelineIntermediateState = cb_args[0]
return fast_latents_step_callback(
progress_state.latents, progress_state.step, **kwargs
)
return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
else:
return fast_latents_step_callback(*cb_args, **kwargs)

View File

@@ -561,7 +561,7 @@ class Args(object):
"--autoimport",
default=None,
type=str,
help="(DEPRECATED - NONFUNCTIONAL). Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
)
model_group.add_argument(
"--autoconvert",

View File

@@ -67,6 +67,7 @@ def install_requested_models(
scan_directory: Path = None,
external_models: List[str] = None,
scan_at_startup: bool = False,
convert_to_diffusers: bool = False,
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,
@@ -112,6 +113,7 @@ def install_requested_models(
try:
model_manager.heuristic_import(
path_url_or_repo,
convert=convert_to_diffusers,
commit_to_conf=config_file_path,
)
except KeyboardInterrupt:
@@ -120,7 +122,7 @@ def install_requested_models(
pass
if scan_at_startup and scan_directory.is_dir():
argument = "--autoconvert"
argument = "--autoconvert" if convert_to_diffusers else "--autoimport"
initfile = Path(Globals.root, Globals.initfile)
replacement = Path(Globals.root, f"{Globals.initfile}.new")
directory = str(scan_directory).replace("\\", "/")

View File

@@ -41,7 +41,7 @@ class PngWriter:
info = PngImagePlugin.PngInfo()
info.add_text("Dream", dream_prompt)
if metadata:
info.add_text("invokeai", json.dumps(metadata))
info.add_text("sd-metadata", json.dumps(metadata))
image.save(path, "PNG", pnginfo=info, compress_level=compress_level)
return path

View File

@@ -7,4 +7,3 @@ from .convert_ckpt_to_diffusers import (
)
from .model_manager import ModelManager

View File

@@ -1,4 +1,4 @@
"""enum
"""
Manage a cache of Stable Diffusion model files for fast switching.
They are moved between GPU and CPU as necessary. If CPU memory falls
below a preset minimum, the least recently used model will be
@@ -15,7 +15,7 @@ import sys
import textwrap
import time
import warnings
from enum import Enum, auto
from enum import Enum
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Optional, Union, Callable
@@ -24,12 +24,8 @@ import safetensors
import safetensors.torch
import torch
import transformers
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
SchedulerMixin,
logging as dlogging,
)
from diffusers import AutoencoderKL
from diffusers import logging as dlogging
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@@ -37,58 +33,37 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from transformers import (
CLIPTextModel,
CLIPTokenizer,
CLIPFeatureExtractor,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from ..stable_diffusion import (
StableDiffusionGeneratorPipeline,
)
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = auto()
V1_INPAINT = auto()
V2 = auto()
V2_e = auto()
V2_v = auto()
UNKNOWN = auto()
V1 = 1
V1_INPAINT = 2
V2 = 3
V2_e = 4
V2_v = 5
UNKNOWN = 99
class SDModelComponent(Enum):
vae="vae"
text_encoder="text_encoder"
tokenizer="tokenizer"
unet="unet"
scheduler="scheduler"
safety_checker="safety_checker"
feature_extractor="feature_extractor"
DEFAULT_MAX_MODELS = 2
class ModelManager(object):
"""
'''
Model manager handles loading, caching, importing, deleting, converting, and editing models.
"""
'''
def __init__(
self,
config: OmegaConf | Path,
device_type: torch.device = CUDA_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload=False,
embedding_path: Path = None,
self,
config: OmegaConf|Path,
device_type: torch.device = CUDA_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload=False,
embedding_path: Path=None,
):
"""
Initialize with the path to the models.yaml config file or
an initialized OmegaConf dictionary. Optional parameters
are the torch device type, precision, max_loaded_models,
and sequential_offload boolean. Note that the default device
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
# prevent nasty-looking CLIP log message
@@ -112,25 +87,15 @@ class ModelManager(object):
"""
return model_name in self.config
def get_model(self, model_name: str = None) -> dict:
"""Given a model named identified in models.yaml, return a dict
containing the model object and some of its key features. If
in RAM will load into GPU VRAM. If on disk, will load from
there.
The dict has the following keys:
'model': The StableDiffusionGeneratorPipeline object
'model_name': The name of the model in models.yaml
'width': The width of images trained by this model
'height': The height of images trained by this model
'hash': A unique hash of this model's files on disk.
def get_model(self, model_name: str=None)->dict:
"""
Given a model named identified in models.yaml, return
the model object. If in RAM will load into GPU VRAM.
If on disk, will load from there.
"""
if not model_name:
return (
self.get_model(self.current_model)
if self.current_model
else self.get_model(self.default_model())
)
return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
if not self.valid_model(model_name):
print(
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
@@ -170,81 +135,6 @@ class ModelManager(object):
"hash": hash,
}
def get_model_vae(self, model_name: str=None)->AutoencoderKL:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned VAE as an
AutoencoderKL object. If no model name is provided, return the
vae from the model currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.vae)
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned CLIPTokenizer. If no
model name is provided, return the tokenizer from the model
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.tokenizer)
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned UNet2DConditionModel. If no model
name is provided, return the UNet from the model
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.unet)
def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned CLIPTextModel. If no
model name is provided, return the text encoder from the model
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.text_encoder)
def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned CLIPFeatureExtractor. If no
model name is provided, return the text encoder from the model
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.feature_extractor)
def get_model_scheduler(self, model_name: str=None)->SchedulerMixin:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned scheduler. If no
model name is provided, return the text encoder from the model
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.scheduler)
def _get_sub_model(
self,
model_name: str=None,
model_part: SDModelComponent=SDModelComponent.vae,
) -> Union[
AutoencoderKL,
CLIPTokenizer,
CLIPFeatureExtractor,
UNet2DConditionModel,
CLIPTextModel,
StableDiffusionSafetyChecker,
]:
"""Given a model name identified in models.yaml, and the part of the
model you wish to retrieve, return that part. Parts are in an Enum
class named SDModelComponent, and consist of:
SDModelComponent.vae
SDModelComponent.text_encoder
SDModelComponent.tokenizer
SDModelComponent.unet
SDModelComponent.scheduler
SDModelComponent.safety_checker
SDModelComponent.feature_extractor
"""
model_dict = self.get_model(model_name)
model = model_dict["model"]
return getattr(model, model_part.value)
def default_model(self) -> str | None:
"""
Returns the name of the default model, or None
@@ -470,7 +360,7 @@ class ModelManager(object):
f"Unknown model format {model_name}: {model_format}"
)
self._add_embeddings_to_model(model)
# usage statistics
toc = time.time()
print(">> Model loaded in", "%4.2fs" % (toc - tic))
@@ -543,7 +433,7 @@ class ModelManager(object):
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
height = width
print(f" | Default image dimensions = {width} x {height}")
return pipeline, width, height, model_hash
def _load_ckpt_model(self, model_name, mconfig):
@@ -564,18 +454,14 @@ class ModelManager(object):
from . import load_pipeline_from_original_stable_diffusion_ckpt
try:
if self.list_models()[self.current_model]["status"] == "active":
if self.list_models()[self.current_model]['status'] == 'active':
self.offload_model(self.current_model)
except Exception as e:
pass
vae_path = None
if vae:
vae_path = (
vae
if os.path.isabs(vae)
else os.path.normpath(os.path.join(Globals.root, vae))
)
vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
if self._has_cuda():
torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -685,7 +571,9 @@ class ModelManager(object):
models.yaml file.
"""
model_name = model_name or Path(repo_or_path).stem
model_description = description or f"Imported diffusers model {model_name}"
model_description = (
description or f"Imported diffusers model {model_name}"
)
new_config = dict(
description=model_description,
vae=vae,
@@ -714,7 +602,7 @@ class ModelManager(object):
SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
SDLegacyType.UNKNOWN
"""
global_step = checkpoint.get("global_step")
global_step = checkpoint.get('global_step')
state_dict = checkpoint.get("state_dict") or checkpoint
try:
@@ -740,13 +628,13 @@ class ModelManager(object):
return SDLegacyType.UNKNOWN
def heuristic_import(
self,
path_url_or_repo: str,
model_name: str = None,
description: str = None,
model_config_file: Path = None,
commit_to_conf: Path = None,
config_file_callback: Callable[[Path], Path] = None,
self,
path_url_or_repo: str,
model_name: str = None,
description: str = None,
model_config_file: Path = None,
commit_to_conf: Path = None,
config_file_callback: Callable[[Path], Path] = None,
) -> str:
"""Accept a string which could be:
- a HF diffusers repo_id
@@ -850,8 +738,8 @@ class ModelManager(object):
# another round of heuristics to guess the correct config file.
checkpoint = None
if model_path.suffix in [".ckpt", ".pt"]:
self.scan_model(model_path, model_path)
if model_path.suffix in [".ckpt",".pt"]:
self.scan_model(model_path,model_path)
checkpoint = torch.load(model_path)
else:
checkpoint = safetensors.torch.load_file(model_path)
@@ -873,16 +761,19 @@ class ModelManager(object):
elif model_type == SDLegacyType.V1_INPAINT:
print(" | SD-v1 inpainting model detected")
model_config_file = Path(
Globals.root,
"configs/stable-diffusion/v1-inpainting-inference.yaml",
Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
)
elif model_type == SDLegacyType.V2_v:
print(" | SD-v2-v model detected")
print(
" | SD-v2-v model detected"
)
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
)
elif model_type == SDLegacyType.V2_e:
print(" | SD-v2-e model detected")
print(
" | SD-v2-e model detected"
)
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
)
@@ -929,16 +820,16 @@ class ModelManager(object):
return model_name
def convert_and_import(
self,
ckpt_path: Path,
diffusers_path: Path,
model_name=None,
model_description=None,
vae: dict = None,
vae_path: Path = None,
original_config_file: Path = None,
commit_to_conf: Path = None,
scan_needed: bool = True,
self,
ckpt_path: Path,
diffusers_path: Path,
model_name=None,
model_description=None,
vae:dict=None,
vae_path:Path=None,
original_config_file: Path = None,
commit_to_conf: Path = None,
scan_needed: bool=True,
) -> str:
"""
Convert a legacy ckpt weights file to diffuser model and import
@@ -966,10 +857,10 @@ class ModelManager(object):
try:
# By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file
vae_model = None
vae_model=None
if vae:
vae_model = self._load_vae(vae)
vae_path = None
vae_model=self._load_vae(vae)
vae_path=None
convert_ckpt_to_diffusers(
ckpt_path,
diffusers_path,
@@ -1085,16 +976,16 @@ class ModelManager(object):
legacy_locations = [
Path(
models_dir,
"CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
"CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker"
),
Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
Path(
models_dir,
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14"
),
]
legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
legacy_locations.extend(list(global_cache_dir("diffusers").glob('*')))
legacy_layout = False
for model in legacy_locations:
legacy_layout = legacy_layout or model.exists()
@@ -1112,7 +1003,7 @@ class ModelManager(object):
>> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
>> Otherwise press <enter> to continue."""
)
input("continue> ")
input('continue> ')
# transformer files get moved into the hub directory
if cls._is_huggingface_hub_directory_present():
@@ -1199,12 +1090,12 @@ class ModelManager(object):
print(
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
)
def _has_cuda(self) -> bool:
return self.device.type == "cuda"
def _diffuser_sha256(
self, name_or_path: Union[str, Path], chunksize=16777216
self, name_or_path: Union[str, Path], chunksize=4096
) -> Union[str, bytes]:
path = None
if isinstance(name_or_path, Path):

View File

@@ -531,7 +531,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id: str = None,
additional_guidance: List[Callable] = None,
):
self._adjust_memory_efficient_attention(latents)
# FIXME: do we still use any slicing now that PyTorch 2.0 has scaled dot-product attention on all platforms?
# self._adjust_memory_efficient_attention(latents)
if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH)
if additional_guidance is None:

View File

@@ -158,9 +158,14 @@ def main():
report_model_error(opt, e)
# try to autoconvert new models
if path := opt.autoimport:
gen.model_manager.heuristic_import(
str(path), convert=False, commit_to_conf=opt.conf
)
if path := opt.autoconvert:
gen.model_manager.heuristic_import(
str(path), commit_to_conf=opt.conf
str(path), convert=True, commit_to_conf=opt.conf
)
# web server loops forever
@@ -576,7 +581,6 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
elif command.startswith("!replay"):
file_path = command.replace("!replay", "", 1).strip()
file_path = os.path.join(opt.outdir, file_path)
if infile is None and os.path.isfile(file_path):
infile = open(file_path, "r", encoding="utf-8")
completer.add_history(command)

View File

@@ -199,6 +199,17 @@ class addModelsForm(npyscreen.FormMultiPage):
relx=4,
scroll_exit=True,
)
self.nextrely += 1
self.convert_models = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
values=["Keep original format", "Convert to diffusers"],
value=0,
begin_entry_at=4,
max_height=4,
hidden=True, # will appear when imported models box is edited
scroll_exit=True,
)
self.cancel = self.add_widget_intelligent(
npyscreen.ButtonPress,
name="CANCEL",
@@ -233,6 +244,8 @@ class addModelsForm(npyscreen.FormMultiPage):
self.show_directory_fields.addVisibleWhenSelected(i)
self.show_directory_fields.when_value_edited = self._clear_scan_directory
self.import_model_paths.when_value_edited = self._show_hide_convert
self.autoload_directory.when_value_edited = self._show_hide_convert
def resize(self):
super().resize()
@@ -243,6 +256,13 @@ class addModelsForm(npyscreen.FormMultiPage):
if not self.show_directory_fields.value:
self.autoload_directory.value = ""
def _show_hide_convert(self):
model_paths = self.import_model_paths.value or ""
autoload_directory = self.autoload_directory.value or ""
self.convert_models.hidden = (
len(model_paths) == 0 and len(autoload_directory) == 0
)
def _get_starter_model_labels(self) -> List[str]:
window_width, window_height = get_terminal_size()
label_width = 25
@@ -302,6 +322,7 @@ class addModelsForm(npyscreen.FormMultiPage):
.scan_directory: Path to a directory of models to scan and import
.autoscan_on_startup: True if invokeai should scan and import at startup time
.import_model_paths: list of URLs, repo_ids and file paths to import
.convert_to_diffusers: if True, convert legacy checkpoints into diffusers
"""
# we're using a global here rather than storing the result in the parentapp
# due to some bug in npyscreen that is causing attributes to be lost
@@ -338,6 +359,7 @@ class addModelsForm(npyscreen.FormMultiPage):
# URLs and the like
selections.import_model_paths = self.import_model_paths.value.split()
selections.convert_to_diffusers = self.convert_models.value[0] == 1
class AddModelApplication(npyscreen.NPSAppManaged):
@@ -350,6 +372,7 @@ class AddModelApplication(npyscreen.NPSAppManaged):
scan_directory=None,
autoscan_on_startup=None,
import_model_paths=None,
convert_to_diffusers=None,
)
def onStart(self):
@@ -370,6 +393,7 @@ def process_and_execute(opt: Namespace, selections: Namespace):
directory_to_scan = selections.scan_directory
scan_at_startup = selections.autoscan_on_startup
potential_models_to_install = selections.import_model_paths
convert_to_diffusers = selections.convert_to_diffusers
install_requested_models(
install_initial_models=models_to_install,
@@ -377,6 +401,7 @@ def process_and_execute(opt: Namespace, selections: Namespace):
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
external_models=potential_models_to_install,
scan_at_startup=scan_at_startup,
convert_to_diffusers=convert_to_diffusers,
precision="float32"
if opt.full_precision
else choose_precision(torch.device(choose_torch_device())),

View File

@@ -6,5 +6,3 @@ stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*

View File

@@ -3,8 +3,4 @@ dist/
node_modules/
patches/
stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*

View File

@@ -1,16 +1,10 @@
# InvokeAI Web UI
- [InvokeAI Web UI](#invokeai-web-ui)
- [Stack](#stack)
- [Contributing](#contributing)
- [Dev Environment](#dev-environment)
- [Production builds](#production-builds)
The UI is a fairly straightforward Typescript React app. The only really fancy stuff is the Unified Canvas.
Code in `invokeai/frontend/web/` if you want to have a look.
## Stack
## Details
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
@@ -38,7 +32,7 @@ Start everything in dev mode:
1. Start the dev server: `yarn dev`
2. Start the InvokeAI UI per usual: `invokeai --web`
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
3. Point your browser to the dev server address e.g. `http://localhost:5173/`
### Production builds

View File

@@ -1,87 +0,0 @@
# Generated axios API client
- [Generated axios API client](#generated-axios-api-client)
- [Generation](#generation)
- [Generate the API client from the nodes web server](#generate-the-api-client-from-the-nodes-web-server)
- [Generate the API client from JSON](#generate-the-api-client-from-json)
- [Getting the JSON from the nodes web server](#getting-the-json-from-the-nodes-web-server)
- [Getting the JSON with a python script](#getting-the-json-with-a-python-script)
- [Generate the API client](#generate-the-api-client)
- [The generated client](#the-generated-client)
- [API client customisation](#api-client-customisation)
This API client is generated by an [openapi code generator](https://github.com/ferdikoomen/openapi-typescript-codegen).
All files in `invokeai/frontend/web/src/services/api/` are made by the generator.
## Generation
The axios client may be generated by from the OpenAPI schema from the nodes web server, or from JSON.
### Generate the API client from the nodes web server
We need to start the nodes web server, which serves the OpenAPI schema to the generator.
1. Start the nodes web server.
```bash
# from the repo root
python scripts/invoke-new.py --web
```
2. Generate the API client.
```bash
# from invokeai/frontend/web/
yarn api:web
```
### Generate the API client from JSON
The JSON can be acquired from the nodes web server, or with a python script.
#### Getting the JSON from the nodes web server
Start the nodes web server as described above, then download the file.
```bash
# from invokeai/frontend/web/
curl http://localhost:9090/openapi.json -o openapi.json
```
#### Getting the JSON with a python script
Run this python script from the repo root, so it can access the nodes server modules.
The script will output `openapi.json` in the repo root. Then we need to move it to `invokeai/frontend/web/`.
```bash
# from the repo root
python invokeai/app/util/generate_openapi_json.py
mv invokeai/app/util/openapi.json invokeai/frontend/web/services/fixtures/
```
#### Generate the API client
Now we can generate the API client from the JSON.
```bash
# from invokeai/frontend/web/
yarn api:file
```
## The generated client
The client will be written to `invokeai/frontend/web/services/api/`:
- `axios` client
- TS types
- An easily parseable schema, which we can use to generate UI
## API client customisation
The generator has a default `request.ts` file that implements a base `axios` client. The generated client uses this base client.
One shortcoming of this is base client is it does not provide response headers unless the response body is empty. To fix this, we provide our own lightly-patched `request.ts`.
To access the headers, call `getHeaders(response)` on any response from the generated api client. This function is exported from `invokeai/frontend/web/src/services/util/getHeaders.ts`.

View File

@@ -1,21 +0,0 @@
# Events
Events via `socket.io`
## `actions.ts`
Redux actions for all socket events. Payloads all include a timestamp, and optionally some other data.
Any reducer (or middleware) can respond to the actions.
## `middleware.ts`
Redux middleware for events.
Handles dispatching the event actions. Only put logic here if it can't really go anywhere else.
For example, on connect we want to load images to the gallery if it's not populated. This requires dispatching a thunk, so we need to directly dispatch this in the middleware.
## `types.ts`
Hand-written types for the socket events. Cannot generate these from the server, but fortunately they are few and simple.

View File

@@ -1,17 +0,0 @@
# Node Editor Design
WIP
nodes
everything in `src/features/nodes/`
have a look at `state.nodes.invocation`
- on socket connect, if no schema saved, fetch `localhost:9090/openapi.json`, save JSON to `state.nodes.schema`
- on fulfilled schema fetch, `parseSchema()` the schema. this outputs a `Record<string, Invocation>` which is saved to `state.nodes.invocations` - `Invocation` is like a template for the node
- when you add a node, the the `Invocation` template is passed to `InvocationComponent.tsx` to build the UI component for that node
- inputs/outputs have field types - and each field type gets an `FieldComponent` which includes a dispatcher to write state changes to redux `nodesSlice`
- `reactflow` sends changes to nodes/edges to redux
- to invoke, `buildNodesGraph()` state, then send this
- changed onClick Invoke button actions to build the schema, then when schema builds it dispatches the actual network request to create the session - see `session.ts`

View File

@@ -1,29 +0,0 @@
# Package Scripts
WIP walkthrough of `package.json` scripts.
## `theme` & `theme:watch`
These run the Chakra CLI to generate types for the theme, or watch for code change and re-generate the types.
The CLI essentially monkeypatches Chakra's files in `node_modules`.
## `postinstall`
The `postinstall` script patches a few packages and runs the Chakra CLI to generate types for the theme.
### Patch `@chakra-ui/cli`
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
### Patch `redux-persist`
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
### Patch `redux-deep-persist`
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.

View File

@@ -1,7 +1,6 @@
import React, { PropsWithChildren } from 'react';
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
import { InvokeTabName } from 'features/ui/store/tabMap';
export {};
@@ -65,24 +64,9 @@ declare module '@invoke-ai/invoke-ai-ui' {
declare class SettingsModal extends React.Component<SettingsModalProps> {
public constructor(props: SettingsModalProps);
}
declare class StatusIndicator extends React.Component<StatusIndicatorProps> {
public constructor(props: StatusIndicatorProps);
}
declare class ModelSelect extends React.Component<ModelSelectProps> {
public constructor(props: ModelSelectProps);
}
}
interface InvokeProps extends PropsWithChildren {
apiUrl?: string;
disabledPanels?: string[];
disabledTabs?: InvokeTabName[];
token?: string;
}
declare function Invoke(props: InvokeProps): JSX.Element;
declare function Invoke(props: PropsWithChildren): JSX.Element;
export {
ThemeChanger,
@@ -90,7 +74,5 @@ export {
IAIPopover,
IAIIconButton,
SettingsModal,
StatusIndicator,
ModelSelect,
};
export = Invoke;

View File

@@ -5,10 +5,7 @@
"scripts": {
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",
@@ -44,10 +41,9 @@
"@chakra-ui/react": "^2.5.1",
"@chakra-ui/styled-system": "^2.6.1",
"@chakra-ui/theme-tools": "^2.0.16",
"@dagrejs/graphlib": "^2.1.12",
"@emotion/react": "^11.10.6",
"@emotion/styled": "^11.10.6",
"@reduxjs/toolkit": "^1.9.3",
"@reduxjs/toolkit": "^1.9.2",
"chakra-ui-contextmenu": "^1.0.5",
"dateformat": "^5.0.3",
"formik": "^2.2.9",
@@ -71,9 +67,7 @@
"react-redux": "^8.0.5",
"react-transition-group": "^4.4.5",
"react-zoom-pan-pinch": "^2.6.1",
"reactflow": "^11.7.0",
"redux-deep-persist": "^1.0.7",
"redux-dynamic-middlewares": "^2.2.0",
"redux-persist": "^6.0.0",
"socket.io-client": "^4.6.0",
"use-image": "^1.1.0",
@@ -89,7 +83,6 @@
"@typescript-eslint/eslint-plugin": "^5.52.0",
"@typescript-eslint/parser": "^5.52.0",
"@vitejs/plugin-react-swc": "^3.2.0",
"axios": "^1.3.4",
"babel-plugin-transform-imports": "^2.0.0",
"concurrently": "^7.6.0",
"eslint": "^8.34.0",
@@ -97,17 +90,13 @@
"eslint-plugin-prettier": "^4.2.1",
"eslint-plugin-react": "^7.32.2",
"eslint-plugin-react-hooks": "^4.6.0",
"form-data": "^4.0.0",
"husky": "^8.0.3",
"lint-staged": "^13.1.2",
"madge": "^6.0.0",
"openapi-types": "^12.1.0",
"openapi-typescript-codegen": "^0.23.0",
"postinstall-postinstall": "^2.1.0",
"prettier": "^2.8.4",
"rollup-plugin-visualizer": "^5.9.0",
"terser": "^5.16.4",
"typescript": "4.9.5",
"vite": "^4.1.2",
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.0.5",

View File

@@ -522,10 +522,6 @@
"resetComplete": "Web UI has been reset. Refresh the page to reload."
},
"toast": {
"serverError": "Server Error",
"disconnected": "Disconnected from Server",
"connected": "Connected to Server",
"canceled": "Processing Canceled",
"tempFoldersEmptied": "Temp Folder Emptied",
"uploadFailed": "Upload failed",
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",

View File

@@ -13,42 +13,16 @@ import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppDispatch, useAppSelector } from './storeHooks';
import { useAppSelector } from './storeHooks';
import { PropsWithChildren, useEffect } from 'react';
import { setDisabledPanels, setDisabledTabs } from 'features/ui/store/uiSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { shouldTransformUrlsChanged } from 'features/system/store/systemSlice';
keepGUIAlive();
interface Props extends PropsWithChildren {
options: {
disabledPanels: string[];
disabledTabs: InvokeTabName[];
shouldTransformUrls?: boolean;
};
}
const App = (props: Props) => {
const App = (props: PropsWithChildren) => {
useToastWatcher();
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
const { setColorMode } = useColorMode();
const dispatch = useAppDispatch();
useEffect(() => {
dispatch(setDisabledPanels(props.options.disabledPanels));
}, [dispatch, props.options.disabledPanels]);
useEffect(() => {
dispatch(setDisabledTabs(props.options.disabledTabs));
}, [dispatch, props.options.disabledTabs]);
useEffect(() => {
dispatch(
shouldTransformUrlsChanged(Boolean(props.options.shouldTransformUrls))
);
}, [dispatch, props.options.shouldTransformUrls]);
useEffect(() => {
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');

View File

@@ -14,8 +14,6 @@
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageMetadata, ImageType } from 'services/api';
import { AnyInvocation } from 'services/events/types';
/**
* TODO:
@@ -115,7 +113,7 @@ export declare type Metadata = SystemGenerationMetadata & {
};
// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
export declare type _Image = {
export declare type Image = {
uuid: string;
url: string;
thumbnail: string;
@@ -126,23 +124,11 @@ export declare type _Image = {
category: GalleryCategory;
isBase64?: boolean;
dreamPrompt?: 'string';
name?: string;
};
/**
* ResultImage
*/
export declare type Image = {
name: string;
type: ImageType;
url: string;
thumbnail: string;
metadata: ImageMetadata;
};
// GalleryImages is an array of Image.
export declare type GalleryImages = {
images: Array<_Image>;
images: Array<Image>;
};
/**
@@ -289,7 +275,7 @@ export declare type SystemStatusResponse = SystemStatus;
export declare type SystemConfigResponse = SystemConfig;
export declare type ImageResultResponse = Omit<_Image, 'uuid'> & {
export declare type ImageResultResponse = Omit<Image, 'uuid'> & {
boundingBox?: IRect;
generationMode: InvokeTabName;
};
@@ -310,7 +296,7 @@ export declare type ErrorResponse = {
};
export declare type GalleryImagesResponse = {
images: Array<Omit<_Image, 'uuid'>>;
images: Array<Omit<Image, 'uuid'>>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};

View File

@@ -13,13 +13,9 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
export const generateImage = createAction<InvokeTabName>(
'socketio/generateImage'
);
export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI._Image>(
'socketio/runFacetool'
);
export const deleteImage = createAction<InvokeAI._Image>(
'socketio/deleteImage'
);
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI.Image>('socketio/runFacetool');
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
export const requestImages = createAction<GalleryCategory>(
'socketio/requestImages'
);

View File

@@ -91,7 +91,7 @@ const makeSocketIOEmitters = (
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
dispatch(setIsProcessing(true));
const {
@@ -119,7 +119,7 @@ const makeSocketIOEmitters = (
})
);
},
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
emitRunFacetool: (imageToProcess: InvokeAI.Image) => {
dispatch(setIsProcessing(true));
const {
@@ -150,7 +150,7 @@ const makeSocketIOEmitters = (
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
const { url, uuid, category, thumbnail } = imageToDelete;
dispatch(removeImage(imageToDelete));
socketio.emit('deleteImage', url, thumbnail, uuid, category);

View File

@@ -34,9 +34,8 @@ import type { RootState } from 'app/store';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import {
clearInitialImage,
initialImageSelected,
setInfillMethod,
// setInitialImage,
setInitialImage,
setMaskPath,
} from 'features/parameters/store/generationSlice';
import { tabMap } from 'features/ui/store/tabMap';
@@ -147,8 +146,7 @@ const makeSocketIOListeners = (
const activeTabName = tabMap[activeTab];
switch (activeTabName) {
case 'img2img': {
dispatch(initialImageSelected(newImage.uuid));
// dispatch(setInitialImage(newImage));
dispatch(setInitialImage(newImage));
break;
}
}
@@ -264,7 +262,7 @@ const makeSocketIOListeners = (
*/
// Generate a UUID for each image
const preparedImages = images.map((image): InvokeAI._Image => {
const preparedImages = images.map((image): InvokeAI.Image => {
return {
uuid: uuidv4(),
...image,
@@ -336,7 +334,7 @@ const makeSocketIOListeners = (
if (
initialImage === url ||
(initialImage as InvokeAI._Image)?.url === url
(initialImage as InvokeAI.Image)?.url === url
) {
dispatch(clearInitialImage());
}

View File

@@ -29,8 +29,6 @@ export const socketioMiddleware = () => {
path: `${window.location.pathname}socket.io`,
});
socketio.disconnect();
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {

View File

@@ -2,35 +2,18 @@ import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { getPersistConfig } from 'redux-deep-persist';
import canvasReducer from 'features/canvas/store/canvasSlice';
import galleryReducer, {
GalleryState,
} from 'features/gallery/store/gallerySlice';
import resultsReducer, {
resultsAdapter,
ResultsState,
} from 'features/gallery/store/resultsSlice';
import uploadsReducer from 'features/gallery/store/uploadsSlice';
import lightboxReducer, {
LightboxState,
} from 'features/lightbox/store/lightboxSlice';
import generationReducer, {
GenerationState,
} from 'features/parameters/store/generationSlice';
import postprocessingReducer, {
PostprocessingState,
} from 'features/parameters/store/postprocessingSlice';
import systemReducer, { SystemState } from 'features/system/store/systemSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
import uiReducer from 'features/ui/store/uiSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer, { NodesState } from 'features/nodes/store/nodesSlice';
import { socketioMiddleware } from './socketio/middleware';
import { socketMiddleware } from 'services/events/middleware';
import { CanvasState } from 'features/canvas/store/canvasTypes';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
@@ -46,21 +29,13 @@ import { CanvasState } from 'features/canvas/store/canvasTypes';
* The necesssary nested persistors with blacklists are configured below.
*/
/**
* Canvas slice persist blacklist
*/
const canvasBlacklist: (keyof CanvasState)[] = [
const canvasBlacklist = [
'cursorPosition',
'isCanvasInitialized',
'doesCanvasNeedScaling',
];
].map((blacklistItem) => `canvas.${blacklistItem}`);
canvasBlacklist.map((blacklistItem) => `canvas.${blacklistItem}`);
/**
* System slice persist blacklist
*/
const systemBlacklist: (keyof SystemState)[] = [
const systemBlacklist = [
'currentIteration',
'currentStatus',
'currentStep',
@@ -73,101 +48,30 @@ const systemBlacklist: (keyof SystemState)[] = [
'totalIterations',
'totalSteps',
'openModel',
'isCancelScheduled',
'sessionId',
'progressImage',
];
'cancelOptions.cancelAfter',
].map((blacklistItem) => `system.${blacklistItem}`);
systemBlacklist.map((blacklistItem) => `system.${blacklistItem}`);
/**
* Gallery slice persist blacklist
*/
const galleryBlacklist: (keyof GalleryState)[] = [
const galleryBlacklist = [
'categories',
'currentCategory',
'currentImage',
'currentImageUuid',
'shouldAutoSwitchToNewImages',
'intermediateImage',
];
].map((blacklistItem) => `gallery.${blacklistItem}`);
galleryBlacklist.map((blacklistItem) => `gallery.${blacklistItem}`);
/**
* Lightbox slice persist blacklist
*/
const lightboxBlacklist: (keyof LightboxState)[] = ['isLightboxOpen'];
lightboxBlacklist.map((blacklistItem) => `lightbox.${blacklistItem}`);
/**
* Nodes slice persist blacklist
*/
const nodesBlacklist: (keyof NodesState)[] = ['schema', 'invocations'];
nodesBlacklist.map((blacklistItem) => `nodes.${blacklistItem}`);
/**
* Generation slice persist blacklist
*/
const generationBlacklist: (keyof GenerationState)[] = [];
generationBlacklist.map((blacklistItem) => `generation.${blacklistItem}`);
/**
* Postprocessing slice persist blacklist
*/
const postprocessingBlacklist: (keyof PostprocessingState)[] = [];
postprocessingBlacklist.map(
(blacklistItem) => `postprocessing.${blacklistItem}`
const lightboxBlacklist = ['isLightboxOpen'].map(
(blacklistItem) => `lightbox.${blacklistItem}`
);
/**
* Results slice persist blacklist
*
* Currently blacklisting results slice entirely, see persist config below
*/
const resultsBlacklist: (keyof ResultsState)[] = [];
resultsBlacklist.map((blacklistItem) => `results.${blacklistItem}`);
/**
* Uploads slice persist blacklist
*
* Currently blacklisting uploads slice entirely, see persist config below
*/
const uploadsBlacklist: (keyof NodesState)[] = [];
uploadsBlacklist.map((blacklistItem) => `uploads.${blacklistItem}`);
/**
* Models slice persist blacklist
*/
const modelsBlacklist: (keyof NodesState)[] = [];
modelsBlacklist.map((blacklistItem) => `models.${blacklistItem}`);
/**
* UI slice persist blacklist
*/
const uiBlacklist: (keyof NodesState)[] = [];
uiBlacklist.map((blacklistItem) => `ui.${blacklistItem}`);
const rootReducer = combineReducers({
canvas: canvasReducer,
gallery: galleryReducer,
generation: generationReducer,
lightbox: lightboxReducer,
models: modelsReducer,
nodes: nodesReducer,
postprocessing: postprocessingReducer,
results: resultsReducer,
gallery: galleryReducer,
system: systemReducer,
canvas: canvasReducer,
ui: uiReducer,
uploads: uploadsReducer,
lightbox: lightboxReducer,
});
const rootPersistConfig = getPersistConfig({
@@ -176,40 +80,23 @@ const rootPersistConfig = getPersistConfig({
rootReducer,
blacklist: [
...canvasBlacklist,
...galleryBlacklist,
...generationBlacklist,
...lightboxBlacklist,
...modelsBlacklist,
...nodesBlacklist,
...postprocessingBlacklist,
// ...resultsBlacklist,
'results',
...systemBlacklist,
...uiBlacklist,
// ...uploadsBlacklist,
'uploads',
...galleryBlacklist,
...lightboxBlacklist,
],
debounce: 300,
});
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
// TODO: rip the old middleware out when nodes is complete
export function buildMiddleware() {
if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
return socketMiddleware();
} else {
return socketioMiddleware();
}
}
// Continue with store setup
export const store = configureStore({
reducer: persistedReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
immutableCheck: false,
serializableCheck: false,
}).concat(dynamicMiddlewares),
}).concat(socketioMiddleware()),
devTools: {
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
actionsDenylist: [

View File

@@ -1,8 +0,0 @@
import { createAsyncThunk } from '@reduxjs/toolkit';
import { AppDispatch, RootState } from './store';
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
export const createAppAsyncThunk = createAsyncThunk.withTypes<{
state: RootState;
dispatch: AppDispatch;
}>();

View File

@@ -2,6 +2,7 @@ import { Box, useToast } from '@chakra-ui/react';
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import useImageUploader from 'common/hooks/useImageUploader';
import { uploadImage } from 'features/gallery/store/thunks/uploadImage';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { ResourceKey } from 'i18next';
import {
@@ -14,7 +15,6 @@ import {
} from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { imageUploaded } from 'services/thunks/image';
import ImageUploadOverlay from './ImageUploadOverlay';
type ImageUploaderProps = {
@@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
const fileAcceptedCallback = useCallback(
async (file: File) => {
dispatch(imageUploaded({ formData: { file } }));
dispatch(uploadImage({ imageFile: file }));
},
[dispatch]
);
@@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
return;
}
dispatch(imageUploaded({ formData: { file } }));
dispatch(uploadImage({ imageFile: file }));
};
document.addEventListener('paste', pasteImageListener);
return () => {

View File

@@ -1,160 +1,27 @@
// import WorkInProgress from './WorkInProgress';
// import ReactFlow, {
// applyEdgeChanges,
// applyNodeChanges,
// Background,
// Controls,
// Edge,
// Handle,
// Node,
// NodeTypes,
// OnEdgesChange,
// OnNodesChange,
// Position,
// } from 'reactflow';
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import WorkInProgress from './WorkInProgress';
// import 'reactflow/dist/style.css';
// import {
// Fragment,
// FunctionComponent,
// ReactNode,
// useCallback,
// useMemo,
// useState,
// } from 'react';
// import { OpenAPIV3 } from 'openapi-types';
// import { filter, map, reduce } from 'lodash';
// import {
// Box,
// Flex,
// FormControl,
// FormLabel,
// Input,
// Select,
// Switch,
// Text,
// NumberInput,
// NumberInputField,
// NumberInputStepper,
// NumberIncrementStepper,
// NumberDecrementStepper,
// Tooltip,
// chakra,
// Badge,
// Heading,
// VStack,
// HStack,
// Menu,
// MenuButton,
// MenuList,
// MenuItem,
// MenuItemOption,
// MenuGroup,
// MenuOptionGroup,
// MenuDivider,
// IconButton,
// } from '@chakra-ui/react';
// import { FaPlus } from 'react-icons/fa';
// import {
// FIELD_NAMES as FIELD_NAMES,
// FIELDS,
// INVOCATION_NAMES as INVOCATION_NAMES,
// INVOCATIONS,
// } from 'features/nodeEditor/constants';
// console.log('invocations', INVOCATIONS);
// const nodeTypes = reduce(
// INVOCATIONS,
// (acc, val, key) => {
// acc[key] = val.component;
// return acc;
// },
// {} as NodeTypes
// );
// console.log('nodeTypes', nodeTypes);
// // make initial nodes one of every node for now
// let n = 0;
// const initialNodes = map(INVOCATIONS, (i) => ({
// id: i.type,
// type: i.title,
// position: { x: (n += 20), y: (n += 20) },
// data: {},
// }));
// console.log('initialNodes', initialNodes);
// export default function NodesWIP() {
// const [nodes, setNodes] = useState<Node[]>([]);
// const [edges, setEdges] = useState<Edge[]>([]);
// const onNodesChange: OnNodesChange = useCallback(
// (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
// []
// );
// const onEdgesChange: OnEdgesChange = useCallback(
// (changes) => setEdges((eds: Edge[]) => applyEdgeChanges(changes, eds)),
// []
// );
// return (
// <Box
// sx={{
// position: 'relative',
// width: 'full',
// height: 'full',
// borderRadius: 'md',
// }}
// >
// <ReactFlow
// nodeTypes={nodeTypes}
// nodes={nodes}
// edges={edges}
// onNodesChange={onNodesChange}
// onEdgesChange={onEdgesChange}
// >
// <Background />
// <Controls />
// </ReactFlow>
// <HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
// {FIELD_NAMES.map((field) => (
// <Badge
// key={field}
// colorScheme={FIELDS[field].color}
// sx={{ userSelect: 'none' }}
// >
// {field}
// </Badge>
// ))}
// </HStack>
// <Menu>
// <MenuButton
// as={IconButton}
// aria-label="Options"
// icon={<FaPlus />}
// sx={{ position: 'absolute', top: 2, left: 2 }}
// />
// <MenuList>
// {INVOCATION_NAMES.map((name) => {
// const invocation = INVOCATIONS[name];
// return (
// <Tooltip
// key={name}
// label={invocation.description}
// placement="end"
// hasArrow
// >
// <MenuItem>{invocation.title}</MenuItem>
// </Tooltip>
// );
// })}
// </MenuList>
// </Menu>
// </Box>
// );
// }
export default {};
export default function NodesWIP() {
const { t } = useTranslation();
return (
<WorkInProgress>
<Flex
sx={{
flexDirection: 'column',
alignItems: 'center',
justifyContent: 'center',
w: '100%',
h: '100%',
gap: 4,
textAlign: 'center',
}}
>
<Heading>{t('common.nodes')}</Heading>
<VStack maxW="50rem" gap={4}>
<Text>{t('common.nodesDesc')}</Text>
</VStack>
</Flex>
</WorkInProgress>
);
}

View File

@@ -14,8 +14,6 @@ const WorkInProgress = (props: WorkInProgressProps) => {
width: '100%',
height: '100%',
bg: 'base.850',
borderRadius: 'base',
position: 'relative',
}}
>
{children}

View File

@@ -1,72 +0,0 @@
import { RootState } from 'app/store';
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
import { find } from 'lodash';
import {
Graph,
ImageToImageInvocation,
TextToImageInvocation,
} from 'services/api';
import { buildHiResNode, buildImg2ImgNode } from './nodes/image2Image';
import { buildIteration } from './nodes/iteration';
import { buildTxt2ImgNode } from './nodes/text2Image';
function mapTabToFunction(activeTabName: InvokeTabName) {
switch (activeTabName) {
case 'txt2img':
return buildTxt2ImgNode;
case 'img2img':
return buildImg2ImgNode;
default:
return buildTxt2ImgNode;
}
}
const buildBaseNode = (
state: RootState
): Record<string, TextToImageInvocation | ImageToImageInvocation> => {
const { activeTab } = state.ui;
const activeTabName = tabMap[activeTab];
return mapTabToFunction(activeTabName)(state);
};
type BuildGraphOutput = {
graph: Graph;
nodeIdsToSubscribe: string[];
};
export const buildGraph = (state: RootState): BuildGraphOutput => {
const { generation, postprocessing } = state;
const { iterations } = generation;
const { hiresFix, hiresStrength } = postprocessing;
const baseNode = buildBaseNode(state);
let graph: Graph = { nodes: baseNode };
const nodeIdsToSubscribe: string[] = [];
if (iterations > 1) {
graph = buildIteration({ graph, iterations });
}
if (hiresFix) {
const { node, edge } = buildHiResNode(
baseNode as Record<string, TextToImageInvocation>,
hiresStrength
);
graph = {
nodes: {
...graph.nodes,
...node,
},
edges: [...(graph.edges || []), edge],
};
nodeIdsToSubscribe.push(Object.keys(node)[0]);
}
console.log('buildGraph: ', graph);
return { graph, nodeIdsToSubscribe };
};

View File

@@ -1,6 +0,0 @@
import dateFormat from 'dateformat';
/**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');

View File

@@ -1,28 +0,0 @@
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { OpenAPI } from 'services/api';
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
};
export const useGetUrl = () => {
const shouldTransformUrls = useAppSelector(
(state: RootState) => state.system.shouldTransformUrls
);
return {
shouldTransformUrls,
getUrl: (url: string) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
},
};
};

View File

@@ -1,98 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store';
import {
Edge,
ImageToImageInvocation,
TextToImageInvocation,
} from 'services/api';
import { _Image } from 'app/invokeai';
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
export const buildImg2ImgNode = (
state: RootState
): Record<string, ImageToImageInvocation> => {
const nodeId = uuidv4();
const { generation, system, models } = state;
const { shouldDisplayInProgressType } = system;
const { currentModel: model } = models;
const {
prompt,
seed,
steps,
width,
height,
cfgScale,
sampler,
seamless,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
} = generation;
const initialImage = initialImageSelector(state);
if (!initialImage) {
// TODO: handle this
throw 'no initial image';
}
return {
[nodeId]: {
id: nodeId,
type: 'img2img',
prompt,
seed: shouldRandomizeSeed ? -1 : seed,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler: sampler as ImageToImageInvocation['scheduler'],
seamless,
model,
progress_images: shouldDisplayInProgressType === 'full-res',
image: {
image_name: initialImage.name,
image_type: initialImage.type,
},
strength,
fit,
},
};
};
type hiresReturnType = {
node: Record<string, ImageToImageInvocation>;
edge: Edge;
};
export const buildHiResNode = (
baseNode: Record<string, TextToImageInvocation>,
strength?: number
): hiresReturnType => {
const nodeId = uuidv4();
const baseNodeId = Object.keys(baseNode)[0];
const baseNodeValues = Object.values(baseNode)[0];
return {
node: {
[nodeId]: {
...baseNodeValues,
id: nodeId,
type: 'img2img',
strength,
},
},
edge: {
source: {
field: 'image',
node_id: baseNodeId,
},
destination: {
field: 'image',
node_id: nodeId,
},
},
};
};

View File

@@ -1,81 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import {
Edge,
Graph,
ImageToImageInvocation,
IterateInvocation,
RangeInvocation,
TextToImageInvocation,
} from 'services/api';
import { buildImg2ImgNode } from './image2Image';
type BuildIteration = {
graph: Graph;
iterations: number;
};
const buildRangeNode = (
iterations: number
): Record<string, RangeInvocation> => {
const nodeId = uuidv4();
return {
[nodeId]: {
id: nodeId,
type: 'range',
start: 0,
stop: iterations,
step: 1,
},
};
};
const buildIterateNode = (): Record<string, IterateInvocation> => {
const nodeId = uuidv4();
return {
[nodeId]: {
id: nodeId,
type: 'iterate',
collection: [],
index: 0,
},
};
};
export const buildIteration = ({
graph,
iterations,
}: BuildIteration): Graph => {
const rangeNode = buildRangeNode(iterations);
const iterateNode = buildIterateNode();
const baseNode: Graph['nodes'] = graph.nodes;
const edges: Edge[] = [
{
source: {
field: 'collection',
node_id: Object.keys(rangeNode)[0],
},
destination: {
field: 'collection',
node_id: Object.keys(iterateNode)[0],
},
},
{
source: {
field: 'item',
node_id: Object.keys(iterateNode)[0],
},
destination: {
field: 'seed',
node_id: Object.keys(baseNode!)[0],
},
},
];
return {
nodes: {
...rangeNode,
...iterateNode,
...graph.nodes,
},
edges,
};
};

View File

@@ -1,43 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store';
import { TextToImageInvocation } from 'services/api';
export const buildTxt2ImgNode = (
state: RootState
): Record<string, TextToImageInvocation> => {
const nodeId = uuidv4();
const { generation, system, models } = state;
const { shouldDisplayInProgressType } = system;
const { currentModel: model } = models;
const {
prompt,
seed,
steps,
width,
height,
cfgScale: cfg_scale,
sampler,
seamless,
shouldRandomizeSeed,
} = generation;
// missing fields in TextToImageInvocation: strength, hires_fix
return {
[nodeId]: {
id: nodeId,
type: 'txt2img',
prompt,
seed: shouldRandomizeSeed ? -1 : seed,
steps,
width,
height,
cfg_scale,
scheduler: sampler as TextToImageInvocation['scheduler'],
seamless,
model,
progress_images: shouldDisplayInProgressType === 'full-res',
},
};
};

View File

@@ -1,10 +1,8 @@
import React, { lazy, PropsWithChildren, useEffect, useState } from 'react';
import React, { lazy, PropsWithChildren } from 'react';
import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { buildMiddleware, store } from './app/store';
import { store } from './app/store';
import { persistor } from './persistor';
import { OpenAPI } from 'services/api';
import { InvokeTabName } from 'features/ui/store/tabMap';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css';
import '@fontsource/inter/300.css';
@@ -19,61 +17,18 @@ import Loading from './Loading';
// Localization
import './i18n';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
const App = lazy(() => import('./app/App'));
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
interface Props extends PropsWithChildren {
apiUrl?: string;
disabledPanels?: string[];
disabledTabs?: InvokeTabName[];
token?: string;
shouldTransformUrls?: boolean;
}
export default function Component({
apiUrl,
disabledPanels = [],
disabledTabs = [],
token,
children,
shouldTransformUrls,
}: Props) {
useEffect(() => {
// configure API client token
if (token) {
OpenAPI.TOKEN = token;
}
// configure API client base url
if (apiUrl) {
OpenAPI.BASE = apiUrl;
}
// reset dynamically added middlewares
resetMiddlewares();
// TODO: at this point, after resetting the middleware, we really ought to clean up the socket
// stuff by calling `dispatch(socketReset())`. but we cannot dispatch from here as we are
// outside the provider. it's not needed until there is the possibility that we will change
// the `apiUrl`/`token` dynamically.
// rebuild socket middleware with token and apiUrl
addMiddleware(buildMiddleware());
}, [apiUrl, token]);
export default function Component(props: PropsWithChildren) {
return (
<React.StrictMode>
<Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}>
<React.Suspense fallback={<Loading showText />}>
<ThemeLocaleProvider>
<App
options={{ disabledPanels, disabledTabs, shouldTransformUrls }}
>
{children}
</App>
<App>{props.children}</App>
</ThemeLocaleProvider>
</React.Suspense>
</PersistGate>

View File

@@ -5,8 +5,6 @@ import ThemeChanger from './features/system/components/ThemeChanger';
import IAIPopover from './common/components/IAIPopover';
import IAIIconButton from './common/components/IAIIconButton';
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
import StatusIndicator from './features/system/components/StatusIndicator';
import ModelSelect from 'features/system/components/ModelSelect';
export default Component;
export {
@@ -15,6 +13,4 @@ export {
IAIPopover,
IAIIconButton,
SettingsModal,
StatusIndicator,
ModelSelect,
};

View File

@@ -1,7 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { ImageConfig } from 'konva/lib/shapes/Image';
import { isEqual } from 'lodash';
@@ -26,7 +25,7 @@ type Props = Omit<ImageConfig, 'image'>;
const IAICanvasIntermediateImage = (props: Props) => {
const { ...rest } = props;
const intermediateImage = useAppSelector(selector);
const { getUrl } = useGetUrl();
const [loadedImageElement, setLoadedImageElement] =
useState<HTMLImageElement | null>(null);
@@ -37,8 +36,8 @@ const IAICanvasIntermediateImage = (props: Props) => {
tempImage.onload = () => {
setLoadedImageElement(tempImage);
};
tempImage.src = getUrl(intermediateImage.url);
}, [intermediateImage, getUrl]);
tempImage.src = intermediateImage.url;
}, [intermediateImage]);
if (!intermediateImage?.boundingBox) return null;

View File

@@ -1,6 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { isEqual } from 'lodash';
@@ -33,7 +32,6 @@ const selector = createSelector(
const IAICanvasObjectRenderer = () => {
const { objects } = useAppSelector(selector);
const { getUrl } = useGetUrl();
if (!objects) return null;
@@ -42,12 +40,7 @@ const IAICanvasObjectRenderer = () => {
{objects.map((obj, i) => {
if (isCanvasBaseImage(obj)) {
return (
<IAICanvasImage
key={i}
x={obj.x}
y={obj.y}
url={getUrl(obj.image.url)}
/>
<IAICanvasImage key={i} x={obj.x} y={obj.y} url={obj.image.url} />
);
} else if (isCanvasBaseLine(obj)) {
const line = (

View File

@@ -1,6 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { GroupConfig } from 'konva/lib/Group';
import { isEqual } from 'lodash';
@@ -54,16 +53,11 @@ const IAICanvasStagingArea = (props: Props) => {
width,
height,
} = useAppSelector(selector);
const { getUrl } = useGetUrl();
return (
<Group {...rest}>
{shouldShowStagingImage && currentStagingAreaImage && (
<IAICanvasImage
url={getUrl(currentStagingAreaImage.image.url)}
x={x}
y={y}
/>
<IAICanvasImage url={currentStagingAreaImage.image.url} x={x} y={y} />
)}
{shouldShowStagingOutline && (
<Group>

View File

@@ -156,7 +156,7 @@ export const canvasSlice = createSlice({
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
state.cursorPosition = action.payload;
},
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
const image = action.payload;
const { stageDimensions } = state;
@@ -291,7 +291,7 @@ export const canvasSlice = createSlice({
state,
action: PayloadAction<{
boundingBox: IRect;
image: InvokeAI._Image;
image: InvokeAI.Image;
}>
) => {
const { boundingBox, image } = action.payload;

View File

@@ -37,7 +37,7 @@ export type CanvasImage = {
y: number;
width: number;
height: number;
image: InvokeAI._Image;
image: InvokeAI.Image;
};
export type CanvasMaskLine = {
@@ -125,7 +125,7 @@ export interface CanvasState {
cursorPosition: Vector2d | null;
doesCanvasNeedScaling: boolean;
futureLayerStates: CanvasLayerState[];
intermediateImage?: InvokeAI._Image;
intermediateImage?: InvokeAI.Image;
isCanvasInitialized: boolean;
isDrawing: boolean;
isMaskEnabled: boolean;

View File

@@ -105,7 +105,7 @@ export const mergeAndUploadCanvas =
const { url, width, height } = image;
const newImage: InvokeAI._Image = {
const newImage: InvokeAI.Image = {
uuid: uuidv4(),
category: shouldSaveToGallery ? 'result' : 'user',
...image,

View File

@@ -14,9 +14,8 @@ import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
import {
initialImageSelected,
setAllParameters,
// setInitialImage,
setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
@@ -46,15 +45,11 @@ import {
FaShareAlt,
FaTrash,
} from 'react-icons/fa';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import { gallerySelector } from '../store/gallerySelectors';
import DeleteImageModal from './DeleteImageModal';
import { useCallback } from 'react';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useGetUrl } from 'common/util/getUrl';
const currentImageButtonsSelector = createSelector(
[
@@ -64,7 +59,6 @@ const currentImageButtonsSelector = createSelector(
uiSelector,
lightboxSelector,
activeTabNameSelector,
selectedImageSelector,
],
(
system: SystemState,
@@ -72,8 +66,7 @@ const currentImageButtonsSelector = createSelector(
postprocessing,
ui,
lightbox,
activeTabName,
selectedImage
activeTabName
) => {
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
system;
@@ -98,7 +91,6 @@ const currentImageButtonsSelector = createSelector(
shouldShowImageDetails,
activeTabName,
isLightboxOpen,
selectedImage,
};
},
{
@@ -125,32 +117,26 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
facetoolStrength,
shouldDisableToolbarButtons,
shouldShowImageDetails,
// currentImage,
currentImage,
isLightboxOpen,
activeTabName,
selectedImage,
} = useAppSelector(currentImageButtonsSelector);
const { getUrl, shouldTransformUrls } = useGetUrl();
const toast = useToast();
const { t } = useTranslation();
const setBothPrompts = useSetBothPrompts();
const handleClickUseAsInitialImage = () => {
if (!selectedImage) return;
if (!currentImage) return;
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
dispatch(initialImageSelected(selectedImage.name));
// dispatch(setInitialImage(currentImage));
// dispatch(setActiveTab('img2img'));
dispatch(setInitialImage(currentImage));
dispatch(setActiveTab('img2img'));
};
const handleCopyImage = async () => {
if (!selectedImage) return;
if (!currentImage) return;
const blob = await fetch(getUrl(selectedImage.url)).then((res) =>
res.blob()
);
const blob = await fetch(currentImage.url).then((res) => res.blob());
const data = [new ClipboardItem({ [blob.type]: blob })];
await navigator.clipboard.write(data);
@@ -164,26 +150,24 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
};
const handleCopyImageLink = () => {
const url = selectedImage
? shouldTransformUrls
? getUrl(selectedImage.url)
: window.location.toString() + selectedImage.url
: '';
navigator.clipboard.writeText(url).then(() => {
toast({
title: t('toast.imageLinkCopied'),
status: 'success',
duration: 2500,
isClosable: true,
navigator.clipboard
.writeText(
currentImage ? window.location.toString() + currentImage.url : ''
)
.then(() => {
toast({
title: t('toast.imageLinkCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
});
});
};
useHotkeys(
'shift+i',
() => {
if (selectedImage) {
if (currentImage) {
handleClickUseAsInitialImage();
toast({
title: t('toast.sentToImageToImage'),
@@ -201,27 +185,24 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[selectedImage]
[currentImage]
);
const handleClickUseAllParameters = () => {
if (!selectedImage) return;
// selectedImage.metadata &&
// dispatch(setAllParameters(selectedImage.metadata));
// if (selectedImage.metadata?.image.type === 'img2img') {
// dispatch(setActiveTab('img2img'));
// } else if (selectedImage.metadata?.image.type === 'txt2img') {
// dispatch(setActiveTab('txt2img'));
// }
if (!currentImage) return;
currentImage.metadata && dispatch(setAllParameters(currentImage.metadata));
if (currentImage.metadata?.image.type === 'img2img') {
dispatch(setActiveTab('img2img'));
} else if (currentImage.metadata?.image.type === 'txt2img') {
dispatch(setActiveTab('txt2img'));
}
};
useHotkeys(
'a',
() => {
if (
['txt2img', 'img2img'].includes(
selectedImage?.metadata?.sd_metadata?.type
)
['txt2img', 'img2img'].includes(currentImage?.metadata?.image?.type)
) {
handleClickUseAllParameters();
toast({
@@ -240,18 +221,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[selectedImage]
[currentImage]
);
const handleClickUseSeed = () => {
selectedImage?.metadata &&
dispatch(setSeed(selectedImage.metadata.sd_metadata.seed));
currentImage?.metadata &&
dispatch(setSeed(currentImage.metadata.image.seed));
};
useHotkeys(
's',
() => {
if (selectedImage?.metadata?.sd_metadata?.seed) {
if (currentImage?.metadata?.image?.seed) {
handleClickUseSeed();
toast({
title: t('toast.seedSet'),
@@ -269,19 +250,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[selectedImage]
[currentImage]
);
const handleClickUsePrompt = useCallback(() => {
if (selectedImage?.metadata?.sd_metadata?.prompt) {
setBothPrompts(selectedImage?.metadata?.sd_metadata?.prompt);
if (currentImage?.metadata?.image?.prompt) {
setBothPrompts(currentImage?.metadata?.image?.prompt);
}
}, [selectedImage?.metadata?.sd_metadata?.prompt, setBothPrompts]);
}, [currentImage?.metadata?.image?.prompt, setBothPrompts]);
useHotkeys(
'p',
() => {
if (selectedImage?.metadata?.sd_metadata?.prompt) {
if (currentImage?.metadata?.image?.prompt) {
handleClickUsePrompt();
toast({
title: t('toast.promptSet'),
@@ -299,11 +280,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[selectedImage]
[currentImage]
);
const handleClickUpscale = () => {
// selectedImage && dispatch(runESRGAN(selectedImage));
currentImage && dispatch(runESRGAN(currentImage));
};
useHotkeys(
@@ -327,7 +308,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
},
[
selectedImage,
currentImage,
isESRGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@@ -337,7 +318,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
);
const handleClickFixFaces = () => {
// selectedImage && dispatch(runFacetool(selectedImage));
currentImage && dispatch(runFacetool(currentImage));
};
useHotkeys(
@@ -361,7 +342,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
},
[
selectedImage,
currentImage,
isGFPGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@@ -374,10 +355,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
dispatch(setShouldShowImageDetails(!shouldShowImageDetails));
const handleSendToCanvas = () => {
if (!selectedImage) return;
if (!currentImage) return;
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
// dispatch(setInitialCanvasImage(selectedImage));
dispatch(setInitialCanvasImage(currentImage));
dispatch(requestCanvasRescale());
if (activeTabName !== 'unifiedCanvas') {
@@ -395,7 +376,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys(
'i',
() => {
if (selectedImage) {
if (currentImage) {
handleClickShowImageDetails();
} else {
toast({
@@ -406,7 +387,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[selectedImage, shouldShowImageDetails]
[currentImage, shouldShowImageDetails]
);
const handleLightBox = () => {
@@ -467,7 +448,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{t('parameters.copyImageToLink')}
</IAIButton>
<Link download={true} href={getUrl(selectedImage!.url)}>
<Link download={true} href={currentImage?.url}>
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
{t('parameters.downloadImage')}
</IAIButton>
@@ -496,7 +477,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!selectedImage?.metadata?.sd_metadata?.prompt}
isDisabled={!currentImage?.metadata?.image?.prompt}
onClick={handleClickUsePrompt}
/>
@@ -504,7 +485,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaSeedling />}
tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!selectedImage?.metadata?.sd_metadata?.seed}
isDisabled={!currentImage?.metadata?.image?.seed}
onClick={handleClickUseSeed}
/>
@@ -514,7 +495,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={
!['txt2img', 'img2img'].includes(
selectedImage?.metadata?.sd_metadata?.type
currentImage?.metadata?.image?.type
)
}
onClick={handleClickUseAllParameters}
@@ -540,7 +521,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton
isDisabled={
!isGFPGANAvailable ||
!selectedImage ||
!currentImage ||
!(isConnected && !isProcessing) ||
!facetoolStrength
}
@@ -569,7 +550,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton
isDisabled={
!isESRGANAvailable ||
!selectedImage ||
!currentImage ||
!(isConnected && !isProcessing) ||
!upscalingLevel
}
@@ -591,15 +572,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/>
</ButtonGroup>
{/* <DeleteImageModal image={selectedImage}>
<DeleteImageModal image={currentImage}>
<IAIIconButton
icon={<FaTrash />}
tooltip={`${t('parameters.deleteImage')} (Del)`}
aria-label={`${t('parameters.deleteImage')} (Del)`}
isDisabled={!selectedImage || !isConnected || isProcessing}
isDisabled={!currentImage || !isConnected || isProcessing}
colorScheme="error"
/>
</DeleteImageModal> */}
</DeleteImageModal>
</Flex>
);
};

View File

@@ -4,20 +4,17 @@ import { useAppSelector } from 'app/storeHooks';
import { isEqual } from 'lodash';
import { MdPhoto } from 'react-icons/md';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import { gallerySelector } from '../store/gallerySelectors';
import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview';
export const currentImageDisplaySelector = createSelector(
[gallerySelector, selectedImageSelector],
(gallery, selectedImage) => {
[gallerySelector],
(gallery) => {
const { currentImage, intermediateImage } = gallery;
return {
hasAnImageToDisplay: selectedImage || intermediateImage,
hasAnImageToDisplay: currentImage || intermediateImage,
};
},
{

View File

@@ -1,46 +1,26 @@
import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { systemSelector } from 'features/system/store/systemSelectors';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import { ReactEventHandler } from 'react';
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
import { selectedImageSelector } from '../store/gallerySelectors';
import { gallerySelector } from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
export const imagesSelector = createSelector(
[uiSelector, selectedImageSelector, systemSelector],
(ui, selectedImage, system) => {
[gallerySelector, uiSelector],
(gallery: GalleryState, ui) => {
const { currentImage, intermediateImage } = gallery;
const { shouldShowImageDetails } = ui;
const { progressImage } = system;
// TODO: Clean this up, this is really gross
const imageToDisplay = progressImage
? {
url: progressImage.dataURL,
width: progressImage.width,
height: progressImage.height,
isProgressImage: true,
image: progressImage,
}
: selectedImage
? {
url: selectedImage.url,
width: selectedImage.metadata.width,
height: selectedImage.metadata.height,
isProgressImage: false,
image: selectedImage,
}
: null;
return {
imageToDisplay: intermediateImage ? intermediateImage : currentImage,
isIntermediate: Boolean(intermediateImage),
shouldShowImageDetails,
imageToDisplay,
};
},
{
@@ -51,9 +31,8 @@ export const imagesSelector = createSelector(
);
export default function CurrentImagePreview() {
const { shouldShowImageDetails, imageToDisplay } =
const { shouldShowImageDetails, imageToDisplay, isIntermediate } =
useAppSelector(imagesSelector);
const { getUrl } = useGetUrl();
return (
<Flex
@@ -67,49 +46,37 @@ export default function CurrentImagePreview() {
>
{imageToDisplay && (
<Image
src={
imageToDisplay.isProgressImage
? imageToDisplay.url
: getUrl(imageToDisplay.url)
}
src={imageToDisplay.url}
width={imageToDisplay.width}
height={imageToDisplay.height}
fallback={
!imageToDisplay.isProgressImage ? (
<CurrentImageFallback />
) : undefined
}
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
sx={{
objectFit: 'contain',
maxWidth: '100%',
maxHeight: '100%',
height: 'auto',
position: 'absolute',
imageRendering: imageToDisplay.isProgressImage
? 'pixelated'
: 'initial',
imageRendering: isIntermediate ? 'pixelated' : 'initial',
borderRadius: 'base',
}}
/>
)}
{!shouldShowImageDetails && <NextPrevImageButtons />}
{shouldShowImageDetails &&
imageToDisplay &&
'metadata' in imageToDisplay.image && (
<Box
sx={{
position: 'absolute',
top: '0',
width: '100%',
height: '100%',
borderRadius: 'base',
overflow: 'scroll',
maxHeight: APP_METADATA_HEIGHT,
}}
>
<ImageMetadataViewer image={imageToDisplay.image} />
</Box>
)}
{shouldShowImageDetails && imageToDisplay && (
<Box
sx={{
position: 'absolute',
top: '0',
width: '100%',
height: '100%',
borderRadius: 'base',
overflow: 'scroll',
maxHeight: APP_METADATA_HEIGHT,
}}
>
<ImageMetadataViewer image={imageToDisplay} />
</Box>
)}
</Flex>
);
}

View File

@@ -52,7 +52,7 @@ interface DeleteImageModalProps {
/**
* The image to delete.
*/
image?: InvokeAI._Image;
image?: InvokeAI.Image;
}
/**

View File

@@ -9,14 +9,11 @@ import {
useToast,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { setCurrentImage } from 'features/gallery/store/gallerySlice';
import {
imageSelected,
setCurrentImage,
} from 'features/gallery/store/gallerySlice';
import {
initialImageSelected,
setAllImageToImageParameters,
setAllParameters,
setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { DragEvent, memo, useState } from 'react';
@@ -34,7 +31,6 @@ import { useTranslation } from 'react-i18next';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import IAIIconButton from 'common/components/IAIIconButton';
import { useGetUrl } from 'common/util/getUrl';
interface HoverableImageProps {
image: InvokeAI.Image;
@@ -44,7 +40,7 @@ interface HoverableImageProps {
const memoEqualityCheck = (
prev: HoverableImageProps,
next: HoverableImageProps
) => prev.image.name === next.image.name && prev.isSelected === next.isSelected;
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
/**
* Gallery image component with delete/use all/use seed buttons on hover.
@@ -59,8 +55,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
shouldUseSingleGalleryColumn,
} = useAppSelector(hoverableImageSelector);
const { image, isSelected } = props;
const { url, thumbnail, name, metadata } = image;
const { getUrl } = useGetUrl();
const { url, thumbnail, uuid, metadata } = image;
const [isHovered, setIsHovered] = useState<boolean>(false);
@@ -74,9 +69,10 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleMouseOut = () => setIsHovered(false);
const handleUsePrompt = () => {
if (image.metadata?.sd_metadata?.prompt) {
setBothPrompts(image.metadata?.sd_metadata?.prompt);
if (image.metadata?.image?.prompt) {
setBothPrompts(image.metadata?.image?.prompt);
}
toast({
title: t('toast.promptSet'),
status: 'success',
@@ -86,8 +82,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseSeed = () => {
image.metadata.sd_metadata &&
dispatch(setSeed(image.metadata.sd_metadata.image.seed));
image.metadata && dispatch(setSeed(image.metadata.image.seed));
toast({
title: t('toast.seedSet'),
status: 'success',
@@ -97,11 +92,20 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleSendToImageToImage = () => {
dispatch(initialImageSelected(image.name));
dispatch(setInitialImage(image));
if (activeTabName !== 'img2img') {
dispatch(setActiveTab('img2img'));
}
toast({
title: t('toast.sentToImageToImage'),
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleSendToCanvas = () => {
// dispatch(setInitialCanvasImage(image));
dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
@@ -118,7 +122,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseAllParameters = () => {
metadata.sd_metadata && dispatch(setAllParameters(metadata.sd_metadata));
metadata && dispatch(setAllParameters(metadata));
toast({
title: t('toast.parametersSet'),
status: 'success',
@@ -128,13 +132,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseInitialImage = async () => {
if (metadata.sd_metadata?.image?.init_image_path) {
const response = await fetch(
metadata.sd_metadata?.image?.init_image_path
);
if (metadata?.image?.init_image_path) {
const response = await fetch(metadata.image.init_image_path);
if (response.ok) {
dispatch(setActiveTab('img2img'));
dispatch(setAllImageToImageParameters(metadata?.sd_metadata));
dispatch(setAllImageToImageParameters(metadata));
toast({
title: t('toast.initialImageSet'),
status: 'success',
@@ -153,18 +155,16 @@ const HoverableImage = memo((props: HoverableImageProps) => {
});
};
const handleSelectImage = () => {
dispatch(imageSelected(image.name));
};
const handleSelectImage = () => dispatch(setCurrentImage(image));
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
// e.dataTransfer.setData('invokeai/imageUuid', uuid);
// e.dataTransfer.effectAllowed = 'move';
e.dataTransfer.setData('invokeai/imageUuid', uuid);
e.dataTransfer.effectAllowed = 'move';
};
const handleLightBox = () => {
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
dispatch(setCurrentImage(image));
dispatch(setIsLightboxOpen(true));
};
return (
@@ -177,30 +177,28 @@ const HoverableImage = memo((props: HoverableImageProps) => {
</MenuItem>
<MenuItem
onClickCapture={handleUsePrompt}
isDisabled={image?.metadata?.sd_metadata?.prompt === undefined}
isDisabled={image?.metadata?.image?.prompt === undefined}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
onClickCapture={handleUseSeed}
isDisabled={image?.metadata?.sd_metadata?.seed === undefined}
isDisabled={image?.metadata?.image?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
onClickCapture={handleUseAllParameters}
isDisabled={
!['txt2img', 'img2img'].includes(
image?.metadata?.sd_metadata?.type
)
!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)
}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
onClickCapture={handleUseInitialImage}
isDisabled={image?.metadata?.sd_metadata?.type !== 'img2img'}
isDisabled={image?.metadata?.image?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem>
@@ -211,9 +209,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
<MenuItem data-warning>
{/* <DeleteImageModal image={image}>
<DeleteImageModal image={image}>
<p>{t('parameters.deleteImage')}</p>
</DeleteImageModal> */}
</DeleteImageModal>
</MenuItem>
</MenuList>
)}
@@ -221,7 +219,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{(ref) => (
<Box
position="relative"
key={name}
key={uuid}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
userSelect="none"
@@ -246,7 +244,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
}
rounded="md"
src={getUrl(thumbnail || url)}
src={thumbnail || url}
loading="lazy"
sx={{
position: 'absolute',
@@ -292,7 +290,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
insetInlineEnd: 1,
}}
>
{/* <DeleteImageModal image={image}>
<DeleteImageModal image={image}>
<IAIIconButton
aria-label={t('parameters.deleteImage')}
icon={<FaTrashAlt />}
@@ -300,7 +298,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
fontSize={14}
isDisabled={!mayDeleteImage}
/>
</DeleteImageModal> */}
</DeleteImageModal>
</Box>
)}
</Box>

View File

@@ -1,4 +1,4 @@
import { ButtonGroup, Flex, Grid, Icon, Image, Text } from '@chakra-ui/react';
import { ButtonGroup, Flex, Grid, Icon, Text } from '@chakra-ui/react';
import { requestImages } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton';
@@ -25,44 +25,9 @@ import HoverableImage from './HoverableImage';
import Scrollable from 'features/ui/components/common/Scrollable';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import {
resultsAdapter,
selectResultsAll,
selectResultsTotal,
} from '../store/resultsSlice';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { selectUploadsAll, uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
const gallerySelector = createSelector(
[
(state: RootState) => state.uploads,
(state: RootState) => state.results,
(state: RootState) => state.gallery,
],
(uploads, results, gallery) => {
const { currentCategory } = gallery;
return currentCategory === 'result'
? {
images: resultsAdapter.getSelectors().selectAll(results),
isLoading: results.isLoading,
areMoreImagesAvailable: results.page < results.pages - 1,
}
: {
images: uploadsAdapter.getSelectors().selectAll(uploads),
isLoading: uploads.isLoading,
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
};
}
);
const ImageGalleryContent = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
@@ -70,7 +35,7 @@ const ImageGalleryContent = () => {
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
const {
// images,
images,
currentCategory,
currentImageUuid,
shouldPinGallery,
@@ -78,24 +43,12 @@ const ImageGalleryContent = () => {
galleryGridTemplateColumns,
galleryImageObjectFit,
shouldAutoSwitchToNewImages,
// areMoreImagesAvailable,
areMoreImagesAvailable,
shouldUseSingleGalleryColumn,
} = useAppSelector(imageGallerySelector);
const { images, areMoreImagesAvailable, isLoading } =
useAppSelector(gallerySelector);
// const handleClickLoadMore = () => {
// dispatch(requestImages(currentCategory));
// };
const handleClickLoadMore = () => {
if (currentCategory === 'result') {
dispatch(receivedResultImagesPage());
}
if (currentCategory === 'user') {
dispatch(receivedUploadImagesPage());
}
dispatch(requestImages(currentCategory));
};
const handleChangeGalleryImageMinimumWidth = (v: number) => {
@@ -250,11 +203,11 @@ const ImageGalleryContent = () => {
style={{ gridTemplateColumns: galleryGridTemplateColumns }}
>
{images.map((image) => {
const { name } = image;
const isSelected = currentImageUuid === name;
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
return (
<HoverableImage
key={name}
key={uuid}
image={image}
isSelected={isSelected}
/>
@@ -264,7 +217,6 @@ const ImageGalleryContent = () => {
<IAIButton
onClick={handleClickLoadMore}
isDisabled={!areMoreImagesAvailable}
isLoading={isLoading}
flexShrink={0}
>
{areMoreImagesAvailable

View File

@@ -11,7 +11,6 @@ import {
} from '@chakra-ui/react';
import * as InvokeAI from 'app/invokeai';
import { useAppDispatch } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import promptToString from 'common/util/promptToString';
import { seedWeightsToString } from 'common/util/seedWeightPairs';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
@@ -19,7 +18,7 @@ import {
setCfgScale,
setHeight,
setImg2imgStrength,
// setInitialImage,
setInitialImage,
setMaskPath,
setPerlin,
setSampler,
@@ -46,7 +45,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import * as png from '@stevebel/png';
type MetadataItemProps = {
isLink?: boolean;
@@ -122,7 +120,7 @@ type ImageMetadataViewerProps = {
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.name === next.image.name;
) => prev.image.uuid === next.image.uuid;
// TODO: Show more interesting information in this component.
@@ -139,8 +137,8 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
dispatch(setShouldShowImageDetails(false));
});
const metadata = image?.metadata.sd_metadata || {};
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
const metadata = image?.metadata?.image || {};
const dreamPrompt = image?.dreamPrompt;
const {
cfg_scale,
@@ -162,23 +160,11 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
type,
variations,
width,
model_weights,
} = metadata;
const { t } = useTranslation();
const { getUrl } = useGetUrl();
const metadataJSON = JSON.stringify(image, null, 2);
// fetch(getUrl(image.url))
// .then((r) => r.arrayBuffer())
// .then((buffer) => {
// const { text } = png.decode(buffer);
// const metadata = text?.['sd-metadata']
// ? JSON.parse(text['sd-metadata'] ?? {})
// : {};
// console.log(metadata);
// });
const metadataJSON = JSON.stringify(image.metadata, null, 2);
return (
<Flex
@@ -197,49 +183,18 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
>
<Flex gap={2}>
<Text fontWeight="semibold">File:</Text>
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
<Link href={image.url} isExternal maxW="calc(100% - 3rem)">
{image.url.length > 64
? image.url.substring(0, 64).concat('...')
: image.url}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
{Object.keys(metadata).length > 0 ? (
<>
{type && <MetadataItem label="Generation type" value={type} />}
{model_weights && (
<MetadataItem label="Model" value={model_weights} />
{image.metadata?.model_weights && (
<MetadataItem label="Model" value={image.metadata.model_weights} />
)}
{['esrgan', 'gfpgan'].includes(type) && (
<MetadataItem label="Original image" value={orig_path} />
@@ -333,14 +288,14 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
onClick={() => dispatch(setHeight(height))}
/>
)}
{/* {init_image_path && (
{init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
)}
{mask_image_path && (
<MetadataItem
label="Mask image"
@@ -453,6 +408,37 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
{dreamPrompt && (
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
)}
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
</>
) : (
<Center width="100%" pt={10}>

View File

@@ -7,16 +7,6 @@ import {
uiSelector,
} from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import {
selectResultsAll,
selectResultsById,
selectResultsEntities,
} from './resultsSlice';
import {
selectUploadsAll,
selectUploadsById,
selectUploadsEntities,
} from './uploadsSlice';
export const gallerySelector = (state: RootState) => state.gallery;
@@ -85,18 +75,3 @@ export const hoverableImageSelector = createSelector(
},
}
);
export const selectedImageSelector = createSelector(
[gallerySelector, selectResultsEntities, selectUploadsEntities],
(gallery, allResults, allUploads) => {
const selectedImageName = gallery.selectedImageName;
if (selectedImageName in allResults) {
return allResults[selectedImageName];
}
if (selectedImageName in allUploads) {
return allUploads[selectedImageName];
}
}
);

View File

@@ -1,17 +1,14 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { invocationComplete } from 'services/events/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { clamp } from 'lodash';
import { isImageOutput } from 'services/types/guards';
import { imageUploaded } from 'services/thunks/image';
export type GalleryCategory = 'user' | 'result';
export type AddImagesPayload = {
images: Array<InvokeAI._Image>;
images: Array<InvokeAI.Image>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};
@@ -19,33 +16,16 @@ export type AddImagesPayload = {
type GalleryImageObjectFitType = 'contain' | 'cover';
export type Gallery = {
images: InvokeAI._Image[];
images: InvokeAI.Image[];
latest_mtime?: number;
earliest_mtime?: number;
areMoreImagesAvailable: boolean;
};
export interface GalleryState {
/**
* The selected image's unique name
* Use `selectedImageSelector` to access the image
*/
selectedImageName: string;
/**
* The currently selected image
* @deprecated See `state.gallery.selectedImageName`
*/
currentImage?: InvokeAI._Image;
/**
* The currently selected image's uuid.
* @deprecated See `state.gallery.selectedImageName`, use `selectedImageSelector` to access the image
*/
currentImage?: InvokeAI.Image;
currentImageUuid: string;
/**
* The current progress image
* @deprecated See `state.system.progressImage`
*/
intermediateImage?: InvokeAI._Image & {
intermediateImage?: InvokeAI.Image & {
boundingBox?: IRect;
generationMode?: InvokeTabName;
};
@@ -62,7 +42,6 @@ export interface GalleryState {
}
const initialState: GalleryState = {
selectedImageName: '',
currentImageUuid: '',
galleryImageMinimumWidth: 64,
galleryImageObjectFit: 'cover',
@@ -90,10 +69,7 @@ export const gallerySlice = createSlice({
name: 'gallery',
initialState,
reducers: {
imageSelected: (state, action: PayloadAction<string>) => {
state.selectedImageName = action.payload;
},
setCurrentImage: (state, action: PayloadAction<InvokeAI._Image>) => {
setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid;
},
@@ -148,7 +124,7 @@ export const gallerySlice = createSlice({
addImage: (
state,
action: PayloadAction<{
image: InvokeAI._Image;
image: InvokeAI.Image;
category: GalleryCategory;
}>
) => {
@@ -174,10 +150,7 @@ export const gallerySlice = createSlice({
setIntermediateImage: (
state,
action: PayloadAction<
InvokeAI._Image & {
boundingBox?: IRect;
generationMode?: InvokeTabName;
}
InvokeAI.Image & { boundingBox?: IRect; generationMode?: InvokeTabName }
>
) => {
state.intermediateImage = action.payload;
@@ -279,31 +252,9 @@ export const gallerySlice = createSlice({
state.shouldUseSingleGalleryColumn = action.payload;
},
},
extraReducers(builder) {
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
if (isImageOutput(data.result)) {
state.selectedImageName = data.result.image.image_name;
state.intermediateImage = undefined;
}
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { location } = action.payload;
const imageName = location.split('/').pop() || '';
state.selectedImageName = imageName;
});
},
});
export const {
imageSelected,
addImage,
clearIntermediateImage,
removeImage,

View File

@@ -1,149 +0,0 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { invocationComplete } from 'services/events/actions';
import { RootState } from 'app/store';
import {
receivedResultImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { isImageOutput } from 'services/types/guards';
import {
buildImageUrls,
deserializeImageField,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { getUrlAlt } from 'common/util/getUrl';
import { ImageMetadata } from 'services/api';
// import { deserializeImageField } from 'services/util/deserializeImageField';
// use `createEntityAdapter` to create a slice for results images
// https://redux-toolkit.js.org/api/createEntityAdapter#overview
// the "Entity" is InvokeAI.ResultImage, while the "entities" are instances of that type
export const resultsAdapter = createEntityAdapter<Image>({
// Provide a callback to get a stable, unique identifier for each entity. This defaults to
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
selectId: (image) => image.name,
// Order all images by their time (in descending order)
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
});
// This type is intersected with the Entity type to create the shape of the state
type AdditionalResultsState = {
// these are a bit misleading; they refer to sessions, not results, but we don't have a route
// to list all images directly at this time...
page: number; // current page we are on
pages: number; // the total number of pages available
isLoading: boolean; // whether we are loading more images or not, mostly a placeholder
nextPage: number; // the next page to request
};
// export type ResultsState = ReturnType<
// typeof resultsAdapter.getInitialState<AdditionalResultsState>
// >;
export const initialResultsState =
resultsAdapter.getInitialState<AdditionalResultsState>({
// provide the additional initial state
page: 0,
pages: 0,
isLoading: false,
nextPage: 0,
});
export type ResultsState = typeof initialResultsState;
const resultsSlice = createSlice({
name: 'results',
initialState: initialResultsState,
reducers: {
// the adapter provides some helper reducers; see the docs for all of them
// can use them as helper functions within a reducer, or use the function itself as a reducer
// here we just use the function itself as the reducer. we'll call this on `invocation_complete`
// to add a single result
resultAdded: resultsAdapter.upsertOne,
},
extraReducers: (builder) => {
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
// because we pass in the fulfilled thunk action creator, everything is typed
/**
* Received Result Images Page - PENDING
*/
builder.addCase(receivedResultImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Result Images Page - FULFILLED
*/
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const resultImages = items.map((image) =>
deserializeImageResponse(image)
);
// use the adapter reducer to append all the results to state
resultsAdapter.addMany(state, resultImages);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
const { result, invocation, graph_execution_state_id, source_id } = data;
if (isImageOutput(result)) {
const name = result.image.image_name;
const type = result.image.image_type;
const { url, thumbnail } = buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width, // TODO: add tese dimensions
height: result.height,
invokeai: {
session: graph_execution_state_id,
source_id,
invocation,
},
},
};
// const resultImage = deserializeImageField(result.image, invocation);
resultsAdapter.addOne(state, image);
}
});
},
});
// Create a set of memoized selectors based on the location of this entity state
// to be used as selectors in a `useAppSelector()` call
export const {
selectAll: selectResultsAll,
selectById: selectResultsById,
selectEntities: selectResultsEntities,
selectIds: selectResultsIds,
selectTotal: selectResultsTotal,
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
export const { resultAdded } = resultsSlice.actions;
export default resultsSlice.reducer;

View File

@@ -0,0 +1,54 @@
import { AnyAction, ThunkAction } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { RootState } from 'app/store';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { setInitialImage } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { v4 as uuidv4 } from 'uuid';
import { addImage } from '../gallerySlice';
type UploadImageConfig = {
imageFile: File;
};
export const uploadImage =
(
config: UploadImageConfig
): ThunkAction<void, RootState, unknown, AnyAction> =>
async (dispatch, getState) => {
const { imageFile } = config;
const state = getState() as RootState;
const activeTabName = activeTabNameSelector(state);
const formData = new FormData();
formData.append('file', imageFile, imageFile.name);
formData.append(
'data',
JSON.stringify({
kind: 'init',
})
);
const response = await fetch(`${window.location.origin}/upload`, {
method: 'POST',
body: formData,
});
const image = (await response.json()) as InvokeAI.ImageUploadResponse;
const newImage: InvokeAI.Image = {
uuid: uuidv4(),
category: 'user',
...image,
};
dispatch(addImage({ image: newImage, category: 'user' }));
if (activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(newImage));
} else if (activeTabName === 'img2img') {
dispatch(setInitialImage(newImage));
}
};

View File

@@ -1,95 +0,0 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { RootState } from 'app/store';
import {
receivedUploadImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { imageUploaded } from 'services/thunks/image';
import { deserializeImageField } from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
export const uploadsAdapter = createEntityAdapter<Image>({
selectId: (image) => image.name,
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
});
type AdditionalUploadsState = {
page: number;
pages: number;
isLoading: boolean;
nextPage: number;
};
export type UploadssState = ReturnType<
typeof uploadsAdapter.getInitialState<AdditionalUploadsState>
>;
const uploadsSlice = createSlice({
name: 'uploads',
initialState: uploadsAdapter.getInitialState<AdditionalUploadsState>({
page: 0,
pages: 0,
nextPage: 0,
isLoading: false,
}),
reducers: {
uploadAdded: uploadsAdapter.addOne,
},
extraReducers: (builder) => {
/**
* Received Upload Images Page - PENDING
*/
builder.addCase(receivedUploadImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Upload Images Page - FULFILLED
*/
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const images = items.map((image) => deserializeImageResponse(image));
uploadsAdapter.addMany(state, images);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { location, response } = action.payload;
const { image_name, image_url, image_type, metadata, thumbnail_url } =
response;
const uploadedImage: Image = {
name: image_name,
url: image_url,
thumbnail: thumbnail_url,
type: 'uploads',
metadata,
};
uploadsAdapter.addOne(state, uploadedImage);
});
},
});
export const {
selectAll: selectUploadsAll,
selectById: selectUploadsById,
selectEntities: selectUploadsEntities,
selectIds: selectUploadsIds,
selectTotal: selectUploadsTotal,
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
export const { uploadAdded } = uploadsSlice.actions;
export default uploadsSlice.reducer;

View File

@@ -1,10 +1,9 @@
import * as React from 'react';
import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
import * as InvokeAI from 'app/invokeai';
import { useGetUrl } from 'common/util/getUrl';
type ReactPanZoomProps = {
image: InvokeAI._Image;
image: InvokeAI.Image;
styleClass?: string;
alt?: string;
ref?: React.Ref<HTMLImageElement>;
@@ -23,7 +22,6 @@ export default function ReactPanZoomImage({
scaleY,
}: ReactPanZoomProps) {
const { centerView } = useTransformContext();
const { getUrl } = useGetUrl();
return (
<TransformComponent
@@ -37,7 +35,7 @@ export default function ReactPanZoomImage({
transform: `rotate(${rotation}deg) scaleX(${scaleX}) scaleY(${scaleY})`,
width: '100%',
}}
src={getUrl(image.url)}
src={image.url}
alt={alt}
ref={ref}
className={styleClass ? styleClass : ''}

View File

@@ -1,47 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import 'reactflow/dist/style.css';
import { useCallback } from 'react';
import {
Tooltip,
Menu,
MenuButton,
MenuList,
MenuItem,
IconButton,
} from '@chakra-ui/react';
import { FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { nodeAdded } from '../store/nodesSlice';
import { map } from 'lodash';
import { RootState } from 'app/store';
export const AddNodeMenu = () => {
const dispatch = useAppDispatch();
const invocations = useAppSelector(
(state: RootState) => state.nodes.invocations
);
const addNode = useCallback(
(nodeType: string) => {
dispatch(nodeAdded({ id: uuidv4(), invocation: invocations[nodeType] }));
},
[dispatch, invocations]
);
return (
<Menu>
<MenuButton as={IconButton} aria-label="Add Node" icon={<FaPlus />} />
<MenuList>
{map(invocations, ({ title, description, type }, key) => {
return (
<Tooltip key={key} label={description} placement="end" hasArrow>
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
</Tooltip>
);
})}
</MenuList>
</Menu>
);
};

View File

@@ -1,78 +0,0 @@
import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, useMemo } from 'react';
import {
Handle,
Position,
Connection,
HandleType,
useReactFlow,
} from 'reactflow';
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../constants';
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
import { InputField, OutputField } from '../types';
const handleBaseStyles: CSSProperties = {
position: 'absolute',
width: '1rem',
height: '1rem',
opacity: 0.5,
borderWidth: 0,
};
const inputHandleStyles: CSSProperties = {
left: '-1.7rem',
};
const outputHandleStyles: CSSProperties = {
right: '-1.7rem',
};
const requiredConnectionStyles: CSSProperties = {
opacity: 1,
};
type FieldHandleProps = {
nodeId: string;
field: InputField | OutputField;
isValidConnection: (connection: Connection) => boolean;
handleType: HandleType;
styles?: CSSProperties;
};
export const FieldHandle = (props: FieldHandleProps) => {
const { nodeId, field, isValidConnection, handleType, styles } = props;
const { name, title, type, description, connectionType } = field;
// this needs to iterate over every candicate target node, calculating graph cycles
// WIP
// const connectionEventStyles = useConnectionEventStyles(
// nodeId,
// type,
// handleType
// );
return (
<Tooltip
key={name}
label={`${title} (${type})`}
placement={handleType === 'target' ? 'start' : 'end'}
hasArrow
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<Handle
type={handleType}
id={name}
isValidConnection={isValidConnection}
position={handleType === 'target' ? Position.Left : Position.Right}
style={{
backgroundColor: `var(--invokeai-colors-${FIELDS[type].color}-500)`,
...styles,
...handleBaseStyles,
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
...(connectionType === 'always' ? requiredConnectionStyles : {}),
// ...connectionEventStyles,
}}
/>
</Tooltip>
);
};

View File

@@ -1,18 +0,0 @@
import 'reactflow/dist/style.css';
import { Tooltip, Badge, HStack } from '@chakra-ui/react';
import { map } from 'lodash';
import { FIELDS } from '../constants';
export const FieldTypeLegend = () => {
return (
<HStack>
{map(FIELDS, ({ title, description, color }, key) => (
<Tooltip key={key} label={description}>
<Badge colorScheme={color} sx={{ userSelect: 'none' }}>
{title}
</Badge>
</Tooltip>
))}
</HStack>
);
};

View File

@@ -1,104 +0,0 @@
import {
Background,
Controls,
MiniMap,
OnConnect,
OnEdgesChange,
OnNodesChange,
ReactFlow,
ConnectionLineType,
OnConnectStart,
OnConnectEnd,
Panel,
} from 'reactflow';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import {
connectionEnded,
connectionMade,
connectionStarted,
edgesChanged,
nodesChanged,
} from '../store/nodesSlice';
import { useCallback } from 'react';
import { InvocationComponent } from './InvocationComponent';
import { AddNodeMenu } from './AddNodeMenu';
import { FieldTypeLegend } from './FieldTypeLegend';
import { Button } from '@chakra-ui/react';
import { nodesGraphBuilt } from 'services/thunks/session';
const nodeTypes = { invocation: InvocationComponent };
export const Flow = () => {
const dispatch = useAppDispatch();
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
const edges = useAppSelector((state: RootState) => state.nodes.edges);
const onNodesChange: OnNodesChange = useCallback(
(changes) => {
dispatch(nodesChanged(changes));
},
[dispatch]
);
const onEdgesChange: OnEdgesChange = useCallback(
(changes) => {
dispatch(edgesChanged(changes));
},
[dispatch]
);
const onConnectStart: OnConnectStart = useCallback(
(event, params) => {
dispatch(connectionStarted(params));
},
[dispatch]
);
const onConnect: OnConnect = useCallback(
(connection) => {
dispatch(connectionMade(connection));
},
[dispatch]
);
const onConnectEnd: OnConnectEnd = useCallback(
(event) => {
dispatch(connectionEnded());
},
[dispatch]
);
const handleInvoke = useCallback(() => {
dispatch(nodesGraphBuilt());
}, [dispatch]);
return (
<ReactFlow
nodeTypes={nodeTypes}
nodes={nodes}
edges={edges}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onConnectStart={onConnectStart}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
defaultEdgeOptions={{
style: { strokeWidth: 2 },
}}
>
<Panel position="top-left">
<AddNodeMenu />
</Panel>
<Panel position="top-center">
<Button onClick={handleInvoke}>Will it blend?</Button>
</Panel>
<Panel position="top-right">
<FieldTypeLegend />
</Panel>
<Background />
<Controls />
<MiniMap nodeStrokeWidth={3} zoomable pannable />
</ReactFlow>
);
};

View File

@@ -1,50 +0,0 @@
import { Box } from '@chakra-ui/react';
import { InputField } from '../types';
import { BooleanInputFieldComponent } from './fields/BooleanInputFieldComponent';
import { EnumInputFieldComponent } from './fields/EnumInputFieldComponent';
import { ImageInputFieldComponent } from './fields/ImageInputFieldComponent';
import { LatentsInputFieldComponent } from './fields/LatentsInputFieldComponent';
import { ModelInputFieldComponent } from './fields/ModelInputFieldComponent';
import { NumberInputFieldComponent } from './fields/NumberInputFieldComponent';
import { StringInputFieldComponent } from './fields/StringInputFieldComponent';
type InputFieldComponentProps = {
nodeId: string;
field: InputField;
};
// build an individual input element based on the schema
export const InputFieldComponent = (props: InputFieldComponentProps) => {
const { nodeId, field } = props;
const { type, value } = field;
if (type === 'string') {
return <StringInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'boolean') {
return <BooleanInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'integer' || type === 'float') {
return <NumberInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'enum') {
return <EnumInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'image') {
return <ImageInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'latents') {
return <LatentsInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'model') {
return <ModelInputFieldComponent nodeId={nodeId} field={field} />;
}
return <Box p={2}>Unknown field type: {type}</Box>;
};

View File

@@ -1,145 +0,0 @@
import { NodeProps, useReactFlow } from 'reactflow';
import {
Box,
Flex,
FormControl,
FormLabel,
Heading,
HStack,
Tooltip,
Icon,
Code,
Text,
} from '@chakra-ui/react';
import { FaInfoCircle } from 'react-icons/fa';
import { Invocation } from '../types';
import { InputFieldComponent } from './InputFieldComponent';
import { FieldHandle } from './FieldHandle';
import { isEqual, map, size } from 'lodash';
import { memo, useMemo } from 'react';
import { useIsValidConnection } from '../hooks/useIsValidConnection';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
const connectedInputFieldsSelector = createSelector(
(state: RootState) => state.nodes.edges,
(edges) => {
return edges.map((e) => e.targetHandle);
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const InvocationComponent = memo((props: NodeProps<Invocation>) => {
const { id, data, selected } = props;
const { type, title, description, inputs, outputs } = data;
const isValidConnection = useIsValidConnection();
const connectedInputs = useAppSelector(connectedInputFieldsSelector);
// TODO: determine if a field/handle is connected and disable the input if so
return (
<Box
sx={{
padding: 4,
bg: 'base.800',
borderRadius: 'md',
boxShadow: 'dark-lg',
borderWidth: 2,
borderColor: selected ? 'base.400' : 'transparent',
}}
>
<Flex flexDirection="column" gap={2}>
<>
<Code>{id}</Code>
<HStack justifyContent="space-between">
<Heading size="sm" fontWeight={500} color="base.100">
{title}
</Heading>
<Tooltip
label={description}
placement="top"
hasArrow
shouldWrapChildren
>
<Icon color="base.300" as={FaInfoCircle} />
</Tooltip>
</HStack>
{map(inputs, (input, i) => {
const isConnected = connectedInputs.includes(input.name);
return (
<Box
key={i}
position="relative"
p={2}
borderWidth={1}
borderRadius="md"
sx={{
borderColor:
!isConnected && input.connectionType === 'always'
? 'warning.400'
: undefined,
}}
>
<FormControl isDisabled={isConnected}>
<HStack justifyContent="space-between" alignItems="center">
<FormLabel>{input.title}</FormLabel>
<Tooltip
label={input.description}
placement="top"
hasArrow
shouldWrapChildren
>
<Icon color="base.400" as={FaInfoCircle} />
</Tooltip>
</HStack>
<InputFieldComponent nodeId={id} field={input} />
</FormControl>
{input.connectionType !== 'never' && (
<FieldHandle
nodeId={id}
field={input}
isValidConnection={isValidConnection}
handleType="target"
/>
)}
</Box>
);
})}
{map(outputs).map((output, i) => {
// const top = `${(100 / (size(outputs) + 1)) * (i + 1)}%`;
const { name, title } = output;
return (
<Box
key={name}
position="relative"
p={2}
borderWidth={1}
borderRadius="md"
>
<FormControl>
<FormLabel textAlign="end">{title} Output</FormLabel>
</FormControl>
<FieldHandle
key={name}
nodeId={id}
field={output}
isValidConnection={isValidConnection}
handleType="source"
/>
</Box>
);
})}
</>
</Flex>
<Flex></Flex>
</Box>
);
});
InvocationComponent.displayName = 'InvocationComponent';

View File

@@ -1,46 +0,0 @@
import 'reactflow/dist/style.css';
import { Box } from '@chakra-ui/react';
import { ReactFlowProvider } from 'reactflow';
import { Flow } from './Flow';
import { useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import { buildNodesGraph } from '../util/buildNodesGraph';
const NodeEditor = () => {
const state = useAppSelector((state: RootState) => state);
const graph = buildNodesGraph(state);
return (
<Box
sx={{
position: 'relative',
width: 'full',
height: 'full',
borderRadius: 'md',
bg: 'base.850',
}}
>
<ReactFlowProvider>
<Flow />
</ReactFlowProvider>
<Box
as="pre"
fontFamily="monospace"
position="absolute"
top={2}
left={2}
width="full"
height="full"
userSelect="none"
pointerEvents="none"
opacity={0.7}
>
<Box w="50%">{JSON.stringify(graph, null, 2)}</Box>
</Box>
</Box>
);
};
export default NodeEditor;

View File

@@ -1,28 +0,0 @@
import { Switch } from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { BooleanInputField } from 'features/nodes/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
export const BooleanInputFieldComponent = (
props: FieldComponentProps<BooleanInputField>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldId: field.name,
value: e.target.checked,
})
);
};
return (
<Switch onChange={handleValueChanged} isChecked={field.value}></Switch>
);
};

View File

@@ -1,32 +0,0 @@
import { Select } from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { EnumInputField } from 'features/nodes/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
export const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputField>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldId: field.name,
value: e.target.value,
})
);
};
return (
<Select onChange={handleValueChanged} value={field.value}>
{field.options.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
);
};

View File

@@ -1,11 +0,0 @@
import { ImageInputField } from 'features/nodes/types';
import { FaImage } from 'react-icons/fa';
import { FieldComponentProps } from './types';
export const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputField>
) => {
const { nodeId, field } = props;
return <FaImage />;
};

View File

@@ -1,11 +0,0 @@
import { LatentsInputField } from 'features/nodes/types';
import { TbBrandMatrix } from 'react-icons/tb';
import { FieldComponentProps } from './types';
export const LatentsInputFieldComponent = (
props: FieldComponentProps<LatentsInputField>
) => {
const { nodeId, field } = props;
return <TbBrandMatrix />;
};

View File

@@ -1,49 +0,0 @@
import { Select } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { ModelInputField } from 'features/nodes/types';
import { isEqual, map } from 'lodash';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
const availableModelsSelector = createSelector(
(state: RootState) => state.models.modelList,
(modelList) => {
return map(modelList, (_, name) => name);
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputField>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const availableModels = useAppSelector(availableModelsSelector);
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldId: field.name,
value: e.target.value,
})
);
};
return (
<Select onChange={handleValueChanged} value={field.value}>
{availableModels.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
);
};

View File

@@ -1,33 +0,0 @@
import {
NumberDecrementStepper,
NumberIncrementStepper,
NumberInput,
NumberInputField,
NumberInputStepper,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { IntegerInputField, FloatInputField } from 'features/nodes/types';
import { FieldComponentProps } from './types';
export const NumberInputFieldComponent = (
props: FieldComponentProps<IntegerInputField | FloatInputField>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (_: string, value: number) => {
dispatch(fieldValueChanged({ nodeId, fieldId: field.name, value }));
};
return (
<NumberInput onChange={handleValueChanged} value={field.value}>
<NumberInputField />
<NumberInputStepper>
<NumberIncrementStepper />
<NumberDecrementStepper />
</NumberInputStepper>
</NumberInput>
);
};

View File

@@ -1,22 +0,0 @@
import { Input } from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { StringInputField } from 'features/nodes/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
export const StringInputFieldComponent = (
props: FieldComponentProps<StringInputField>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldValueChanged({ nodeId, fieldId: field.name, value: e.target.value })
);
};
return <Input onChange={handleValueChanged} value={field.value}></Input>;
};

View File

@@ -1,6 +0,0 @@
import { InputField } from 'features/nodes/types';
export type FieldComponentProps<T extends InputField> = {
nodeId: string;
field: T;
};

View File

@@ -1,57 +0,0 @@
import { FieldType, FieldUIConfig } from './types';
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
export const FIELD_TYPE_MAP: Record<string, FieldType> = {
integer: 'integer',
number: 'float',
string: 'string',
boolean: 'boolean',
enum: 'enum',
ImageField: 'image',
LatentsField: 'latents',
model: 'model',
};
export const FIELDS: Record<FieldType, FieldUIConfig> = {
integer: {
color: 'red',
title: 'Integer',
description: 'Integers are whole numbers, without a decimal point.',
},
float: {
color: 'orange',
title: 'Float',
description: 'Floats are numbers with a decimal point.',
},
string: {
color: 'yellow',
title: 'String',
description: 'Strings are text.',
},
boolean: {
color: 'green',
title: 'Boolean',
description: 'Booleans are true or false.',
},
enum: {
color: 'blue',
title: 'Enum',
description: 'Enums are values that may be one of a number of options.',
},
image: {
color: 'purple',
title: 'Image',
description: 'Images may be passed between nodes.',
},
latents: {
color: 'pink',
title: 'Latents',
description: 'Latents may be passed between nodes.',
},
model: {
color: 'teal',
title: 'Model',
description: 'Models are models.',
},
};

Some files were not shown because too many files have changed in this diff Show More