mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-30 14:18:09 -05:00)

Compare commits: maryhipp/f... → dev/pytorc... (9 commits)
Commit SHA1s: 3c50448ccf, 76bcd4d44f, 50f5e1bc83, 85b020f76c, a7833cc9a9, 919294e977, 7640acfb1f, 5dec5b6f51, e158ad8534

.github/workflows/build-container.yml (vendored)
@@ -18,6 +18,7 @@ on:

permissions:
  contents: write
  packages: write

jobs:
  docker:
@@ -268,7 +268,7 @@ model is so good at inpainting, a good substitute is to use the `clipseg` text
masking option:

```bash
invoke> a fluffy cat eating a hotdot
invoke> a fluffy cat eating a hotdog
Outputs:
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
@@ -461,8 +461,7 @@ def get_torch_source() -> (Union[str, None],str):
        url = "https://download.pytorch.org/whl/cpu"

    if device == 'cuda':
        url = 'https://download.pytorch.org/whl/cu117'
        optional_modules = '[xformers]'
        url = 'https://download.pytorch.org/whl/cu118'

    # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
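For orientation, here is a minimal sketch of the device-to-wheel-index mapping this installer helper performs. It is not the full implementation: the helper name, the `device` values, and the fallback behaviour are assumptions drawn from the surrounding diff lines.

```python
from typing import Optional, Tuple


def pick_torch_source(device: str) -> Tuple[Optional[str], str]:
    """Return (extra index URL, optional pip extras) for a detected device.

    Hypothetical helper mirroring get_torch_source(); only the 'cpu' and 'cuda'
    branches follow the diff above, everything else falls through to PyPI.
    """
    url: Optional[str] = None
    optional_modules = ""
    if device == "cpu":
        url = "https://download.pytorch.org/whl/cpu"
    elif device == "cuda":
        url = "https://download.pytorch.org/whl/cu118"
    # In all other cases, Torch wheels come straight from PyPI.
    return url, optional_modules


# Usage: append f"--extra-index-url {url}" to the pip command when url is not None.
```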
@@ -3,6 +3,8 @@
import os
from argparse import Namespace

from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage

from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
@@ -54,7 +56,9 @@ class ApiDependencies:
            os.path.join(os.path.dirname(__file__), "../../../../outputs")
        )

        images = DiskImageStorage(output_folder)
        latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))

        images = DiskImageStorage(f'{output_folder}/images')

        # TODO: build a file/path manager?
        db_location = os.path.join(output_folder, "invokeai.db")
@@ -62,6 +66,7 @@ class ApiDependencies:
        services = InvocationServices(
            model_manager=get_model_manager(config),
            events=events,
            latents=latents,
            images=images,
            queue=MemoryInvocationQueue(),
            graph_execution_manager=SqliteItemStorage[GraphExecutionState](
@@ -1,20 +1,18 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import uuid

from fastapi import Path, Query, Request, UploadFile
from datetime import datetime, timezone

from fastapi import Path, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.datatypes.image import ImageResponse
from invokeai.app.services.item_storage import PaginatedResults

from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies

images_router = APIRouter(prefix="/v1/images", tags=["images"])


@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
    image_type: ImageType = Path(description="The type of image to get"),
@@ -50,35 +48,19 @@ async def upload_image(file: UploadFile, request: Request):

    contents = await file.read()
    try:
        im = Image.open(io.BytesIO(contents))
        im = Image.open(contents)
    except:
        # Error opening the image
        return Response(status_code=415)

    filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
    filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
    ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)

    return Response(
        status_code=201,
        headers={
            "Location": request.url_for(
                "get_image", image_type=ImageType.UPLOAD.value, image_name=filename
                "get_image", image_type=ImageType.UPLOAD, image_name=filename
            )
        },
    )


@images_router.get(
    "/",
    operation_id="list_images",
    responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
    image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
    page: int = Query(default=0, description="The page of images to get"),
    per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
    """Gets a list of images"""
    result = ApiDependencies.invoker.services.images.list(
        image_type, page, per_page
    )
    return result
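As a point of comparison, here is a minimal, self-contained upload endpoint in the same style. It is a sketch, not the project's actual router: the route path, filename scheme, and exception handling are assumptions, and persistence is left as a comment.

```python
import io
import uuid
from datetime import datetime, timezone

from fastapi import FastAPI, File, Response, UploadFile
from PIL import Image, UnidentifiedImageError

app = FastAPI()


@app.post("/v1/images/uploads", status_code=201)
async def upload_image(file: UploadFile = File(...)) -> Response:
    contents = await file.read()
    try:
        # PIL expects a path or file-like object, so raw bytes are wrapped in BytesIO.
        im = Image.open(io.BytesIO(contents))
    except (UnidentifiedImageError, OSError):
        return Response(status_code=415)  # not a decodable image

    filename = f"{uuid.uuid4()}_{int(datetime.now(timezone.utc).timestamp())}.png"
    # A real service would persist `im` here, e.g. im.save(upload_dir / filename).
    return Response(status_code=201, headers={"Location": f"/v1/images/uploads/{filename}"})
```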
invokeai/app/api/routers/models.py (new file)

@@ -0,0 +1,279 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from typing import Annotated, Any, List, Literal, Optional, Union

from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as

from ..dependencies import ApiDependencies

models_router = APIRouter(prefix="/v1/models", tags=["models"])


class VaeRepo(BaseModel):
    repo_id: str = Field(description="The repo ID to use for this VAE")
    path: Optional[str] = Field(description="The path to the VAE")
    subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")


class ModelInfo(BaseModel):
    description: Optional[str] = Field(description="A description of the model")


class CkptModelInfo(ModelInfo):
    format: Literal['ckpt'] = 'ckpt'

    config: str = Field(description="The path to the model config")
    weights: str = Field(description="The path to the model weights")
    vae: str = Field(description="The path to the model VAE")
    width: Optional[int] = Field(description="The width of the model")
    height: Optional[int] = Field(description="The height of the model")


class DiffusersModelInfo(ModelInfo):
    format: Literal['diffusers'] = 'diffusers'

    vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
    repo_id: Optional[str] = Field(description="The repo ID to use for this model")
    path: Optional[str] = Field(description="The path to the model")


class ModelsList(BaseModel):
    models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]


@models_router.get(
    "/",
    operation_id="list_models",
    responses={200: {"model": ModelsList }},
)
async def list_models() -> ModelsList:
    """Gets a list of models"""
    models_raw = ApiDependencies.invoker.services.model_manager.list_models()
    models = parse_obj_as(ModelsList, { "models": models_raw })
    return models

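A short, self-contained illustration of how the discriminated union in `ModelsList` resolves entries by their `format` field during parsing; the model names, repo id, and path below are invented for the example, and the classes are pared down to the relevant fields.

```python
from typing import Annotated, Literal, Optional, Union

from pydantic import BaseModel, Field, parse_obj_as


class CkptModelInfo(BaseModel):
    format: Literal["ckpt"] = "ckpt"
    weights: str
    description: Optional[str] = None


class DiffusersModelInfo(BaseModel):
    format: Literal["diffusers"] = "diffusers"
    repo_id: Optional[str] = None
    description: Optional[str] = None


class ModelsList(BaseModel):
    # The "format" literal on each entry selects the concrete class during parsing.
    models: dict[str, Annotated[Union[CkptModelInfo, DiffusersModelInfo], Field(discriminator="format")]]


raw = {
    "models": {
        "example-diffusers": {"format": "diffusers", "repo_id": "some-org/some-model"},
        "example-ckpt": {"format": "ckpt", "weights": "models/example.ckpt"},
    }
}
parsed = parse_obj_as(ModelsList, raw)
assert isinstance(parsed.models["example-diffusers"], DiffusersModelInfo)
assert isinstance(parsed.models["example-ckpt"], CkptModelInfo)
```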
# @socketio.on("requestSystemConfig")
|
||||
# def handle_request_capabilities():
|
||||
# print(">> System config requested")
|
||||
# config = self.get_system_config()
|
||||
# config["model_list"] = self.generate.model_manager.list_models()
|
||||
# config["infill_methods"] = infill_methods()
|
||||
# socketio.emit("systemConfig", config)
|
||||
|
||||
# @socketio.on("searchForModels")
|
||||
# def handle_search_models(search_folder: str):
|
||||
# try:
|
||||
# if not search_folder:
|
||||
# socketio.emit(
|
||||
# "foundModels",
|
||||
# {"search_folder": None, "found_models": None},
|
||||
# )
|
||||
# else:
|
||||
# (
|
||||
# search_folder,
|
||||
# found_models,
|
||||
# ) = self.generate.model_manager.search_models(search_folder)
|
||||
# socketio.emit(
|
||||
# "foundModels",
|
||||
# {"search_folder": search_folder, "found_models": found_models},
|
||||
# )
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
# print("\n")
|
||||
|
||||
# @socketio.on("addNewModel")
|
||||
# def handle_add_model(new_model_config: dict):
|
||||
# try:
|
||||
# model_name = new_model_config["name"]
|
||||
# del new_model_config["name"]
|
||||
# model_attributes = new_model_config
|
||||
# if len(model_attributes["vae"]) == 0:
|
||||
# del model_attributes["vae"]
|
||||
# update = False
|
||||
# current_model_list = self.generate.model_manager.list_models()
|
||||
# if model_name in current_model_list:
|
||||
# update = True
|
||||
|
||||
# print(f">> Adding New Model: {model_name}")
|
||||
|
||||
# self.generate.model_manager.add_model(
|
||||
# model_name=model_name,
|
||||
# model_attributes=model_attributes,
|
||||
# clobber=True,
|
||||
# )
|
||||
# self.generate.model_manager.commit(opt.conf)
|
||||
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
# socketio.emit(
|
||||
# "newModelAdded",
|
||||
# {
|
||||
# "new_model_name": model_name,
|
||||
# "model_list": new_model_list,
|
||||
# "update": update,
|
||||
# },
|
||||
# )
|
||||
# print(f">> New Model Added: {model_name}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
|
||||
# @socketio.on("deleteModel")
|
||||
# def handle_delete_model(model_name: str):
|
||||
# try:
|
||||
# print(f">> Deleting Model: {model_name}")
|
||||
# self.generate.model_manager.del_model(model_name)
|
||||
# self.generate.model_manager.commit(opt.conf)
|
||||
# updated_model_list = self.generate.model_manager.list_models()
|
||||
# socketio.emit(
|
||||
# "modelDeleted",
|
||||
# {
|
||||
# "deleted_model_name": model_name,
|
||||
# "model_list": updated_model_list,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Model Deleted: {model_name}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
|
||||
# @socketio.on("requestModelChange")
|
||||
# def handle_set_model(model_name: str):
|
||||
# try:
|
||||
# print(f">> Model change requested: {model_name}")
|
||||
# model = self.generate.set_model(model_name)
|
||||
# model_list = self.generate.model_manager.list_models()
|
||||
# if model is None:
|
||||
# socketio.emit(
|
||||
# "modelChangeFailed",
|
||||
# {"model_name": model_name, "model_list": model_list},
|
||||
# )
|
||||
# else:
|
||||
# socketio.emit(
|
||||
# "modelChanged",
|
||||
# {"model_name": model_name, "model_list": model_list},
|
||||
# )
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
|
||||
# @socketio.on("convertToDiffusers")
|
||||
# def convert_to_diffusers(model_to_convert: dict):
|
||||
# try:
|
||||
# if model_info := self.generate.model_manager.model_info(
|
||||
# model_name=model_to_convert["model_name"]
|
||||
# ):
|
||||
# if "weights" in model_info:
|
||||
# ckpt_path = Path(model_info["weights"])
|
||||
# original_config_file = Path(model_info["config"])
|
||||
# model_name = model_to_convert["model_name"]
|
||||
# model_description = model_info["description"]
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Model is not a valid checkpoint file"}
|
||||
# )
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Could not retrieve model info."}
|
||||
# )
|
||||
|
||||
# if not ckpt_path.is_absolute():
|
||||
# ckpt_path = Path(Globals.root, ckpt_path)
|
||||
|
||||
# if original_config_file and not original_config_file.is_absolute():
|
||||
# original_config_file = Path(Globals.root, original_config_file)
|
||||
|
||||
# diffusers_path = Path(
|
||||
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if model_to_convert["save_location"] == "root":
|
||||
# diffusers_path = Path(
|
||||
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if (
|
||||
# model_to_convert["save_location"] == "custom"
|
||||
# and model_to_convert["custom_location"] is not None
|
||||
# ):
|
||||
# diffusers_path = Path(
|
||||
# model_to_convert["custom_location"], f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if diffusers_path.exists():
|
||||
# shutil.rmtree(diffusers_path)
|
||||
|
||||
# self.generate.model_manager.convert_and_import(
|
||||
# ckpt_path,
|
||||
# diffusers_path,
|
||||
# model_name=model_name,
|
||||
# model_description=model_description,
|
||||
# vae=None,
|
||||
# original_config_file=original_config_file,
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
# socketio.emit(
|
||||
# "modelConverted",
|
||||
# {
|
||||
# "new_model_name": model_name,
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Model Converted: {model_name}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
|
||||
# @socketio.on("mergeDiffusersModels")
|
||||
# def merge_diffusers_models(model_merge_info: dict):
|
||||
# try:
|
||||
# models_to_merge = model_merge_info["models_to_merge"]
|
||||
# model_ids_or_paths = [
|
||||
# self.generate.model_manager.model_name_or_path(x)
|
||||
# for x in models_to_merge
|
||||
# ]
|
||||
# merged_pipe = merge_diffusion_models(
|
||||
# model_ids_or_paths,
|
||||
# model_merge_info["alpha"],
|
||||
# model_merge_info["interp"],
|
||||
# model_merge_info["force"],
|
||||
# )
|
||||
|
||||
# dump_path = global_models_dir() / "merged_models"
|
||||
# if model_merge_info["model_merge_save_path"] is not None:
|
||||
# dump_path = Path(model_merge_info["model_merge_save_path"])
|
||||
|
||||
# os.makedirs(dump_path, exist_ok=True)
|
||||
# dump_path = dump_path / model_merge_info["merged_model_name"]
|
||||
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
|
||||
# merged_model_config = dict(
|
||||
# model_name=model_merge_info["merged_model_name"],
|
||||
# description=f'Merge of models {", ".join(models_to_merge)}',
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||
# "vae", None
|
||||
# ):
|
||||
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
||||
# merged_model_config.update(vae=vae)
|
||||
|
||||
# self.generate.model_manager.import_diffuser_model(
|
||||
# dump_path, **merged_model_config
|
||||
# )
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
|
||||
# socketio.emit(
|
||||
# "modelsMerged",
|
||||
# {
|
||||
# "merged_models": models_to_merge,
|
||||
# "merged_model_name": model_merge_info["merged_model_name"],
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Models Merged: {models_to_merge}")
|
||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
@@ -14,7 +14,7 @@ from pydantic.schema import schema

from ..backend import Args
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions
from .api.routers import images, sessions, models
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
@@ -76,6 +76,8 @@ app.include_router(sessions.session_router, prefix="/api")

app.include_router(images.images_router, prefix="/api")

app.include_router(models.models_router, prefix="/api")


# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
@@ -4,9 +4,9 @@ from abc import ABC, abstractmethod
import argparse
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field

from invokeai.app.datatypes.image import ImageField

import networkx as nx
import matplotlib.pyplot as plt

from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState
from ..services.invoker import Invoker
@@ -47,7 +47,7 @@ def add_parsers(
        f"--{name}",
        dest=name,
        type=field_type,
        default=field.default,
        default=field.default if field.default_factory is None else field.default_factory(),
        choices=allowed_values,
        help=field.field_info.description,
    )
@@ -56,7 +56,7 @@ def add_parsers(
        f"--{name}",
        dest=name,
        type=field.type_,
        default=field.default,
        default=field.default if field.default_factory is None else field.default_factory(),
        help=field.field_info.description,
    )
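The changed `default=` expression above guards against pydantic fields whose default comes from a factory rather than a literal value. A standalone sketch of the same pattern; the model and its fields are invented for the example.

```python
import argparse
from typing import List

from pydantic import BaseModel, Field


class ExampleInvocationArgs(BaseModel):
    steps: int = 10                                # plain default
    tags: List[str] = Field(default_factory=list)  # factory default: .default is None


parser = argparse.ArgumentParser()
for name, field in ExampleInvocationArgs.__fields__.items():
    # Mirror the diff: fall back to the factory when there is no literal default,
    # otherwise argparse would bake in None for factory-backed fields.
    default = field.default if field.default_factory is None else field.default_factory()
    parser.add_argument(f"--{name}", dest=name, default=default)

print(parser.parse_args([]))  # Namespace(steps=10, tags=[])
```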
@@ -201,3 +201,39 @@ class SetDefaultCommand(BaseCommand):
            del context.defaults[self.field]
        else:
            context.defaults[self.field] = self.value


class DrawGraphCommand(BaseCommand):
    """Debugs a graph"""
    type: Literal['draw_graph'] = 'draw_graph'

    def run(self, context: CliContext) -> None:
        session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
        nxgraph = session.graph.nx_graph_flat()

        # Draw the networkx graph
        plt.figure(figsize=(20, 20))
        pos = nx.spectral_layout(nxgraph)
        nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
        nx.draw_networkx_edges(nxgraph, pos, width=2)
        nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
        plt.axis("off")
        plt.show()


class DrawExecutionGraphCommand(BaseCommand):
    """Debugs an execution graph"""
    type: Literal['draw_xgraph'] = 'draw_xgraph'

    def run(self, context: CliContext) -> None:
        session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
        nxgraph = session.execution_graph.nx_graph_flat()

        # Draw the networkx graph
        plt.figure(figsize=(20, 20))
        pos = nx.spectral_layout(nxgraph)
        nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
        nx.draw_networkx_edges(nxgraph, pos, width=2)
        nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
        plt.axis("off")
        plt.show()
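For a quick sanity check of the same plotting calls outside the CLI, this tiny script draws a toy graph with the identical networkx/matplotlib sequence; the node names stand in for invocation ids and are arbitrary.

```python
import matplotlib.pyplot as plt
import networkx as nx

# A stand-in for the flattened invocation graph.
g = nx.DiGraph()
g.add_edges_from([("noise", "t2l"), ("t2l", "l2i")])

plt.figure(figsize=(6, 6))
pos = nx.spectral_layout(g)
nx.draw_networkx_nodes(g, pos, node_size=1000)
nx.draw_networkx_edges(g, pos, width=2)
nx.draw_networkx_labels(g, pos, font_size=12, font_family="sans-serif")
plt.axis("off")
plt.show()
```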
@@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import time
|
||||
from typing import (
|
||||
@@ -12,6 +13,8 @@ from typing import (
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
||||
from .cli.completer import set_autocompleter
|
||||
@@ -20,7 +23,7 @@ from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
@@ -44,7 +47,7 @@ def add_invocation_args(command_parser):
|
||||
"-l",
|
||||
action="append",
|
||||
nargs=3,
|
||||
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
|
||||
help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)",
|
||||
)
|
||||
|
||||
command_parser.add_argument(
|
||||
@@ -94,6 +97,9 @@ def generate_matching_edges(
|
||||
invalid_fields = set(["type", "id"])
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
# Validate types
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
|
||||
|
||||
edges = [
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=a.id, field=field),
|
||||
@@ -149,7 +155,8 @@ def invoke_cli():
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
images=DiskImageStorage(output_folder),
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
images=DiskImageStorage(f'{output_folder}/images'),
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
@@ -162,6 +169,8 @@ def invoke_cli():
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
parser = get_command_parser()
|
||||
|
||||
re_negid = re.compile('^-[0-9]+$')
|
||||
|
||||
# Uncomment to print out previous sessions at startup
|
||||
# print(services.session_manager.list())
|
||||
|
||||
@@ -227,7 +236,11 @@ def invoke_cli():
|
||||
# Parse provided links
|
||||
if "link_node" in args and args["link_node"]:
|
||||
for link in args["link_node"]:
|
||||
link_node = context.session.graph.get_node(link)
|
||||
node_id = link
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
link_node = context.session.graph.get_node(node_id)
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command
|
||||
)
|
||||
@@ -237,10 +250,15 @@ def invoke_cli():
|
||||
|
||||
if "link" in args and args["link"]:
|
||||
for link in args["link"]:
|
||||
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
|
||||
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
|
||||
|
||||
node_id = link[0]
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
edges.append(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=link[1], field=link[0]),
|
||||
source=EdgeConnection(node_id=node_id, field=link[1]),
|
||||
destination=EdgeConnection(
|
||||
node_id=command.command.id, field=link[2]
|
||||
)
|
||||
|
||||
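The `--link` handling above replaces the `and` in the edge filter with `or`, so an existing edge is dropped only when it collides with the new link on both the destination node and the destination field, and negative node ids are resolved relative to the current history id. A small illustration of both rules with invented edge data; the `Conn`/`Edge` dataclasses are simplified stand-ins for the project's `EdgeConnection`/`Edge`.

```python
import re
from dataclasses import dataclass


@dataclass(frozen=True)
class Conn:
    node_id: str
    field: str


@dataclass(frozen=True)
class Edge:
    source: Conn
    destination: Conn


re_negid = re.compile(r"^-[0-9]+$")
current_id = 7

# "-1" means "the node created one step ago", i.e. id 6 here.
node_id = "-1"
if re_negid.match(node_id):
    node_id = str(current_id + int(node_id))
assert node_id == "6"

new_dest = Conn(node_id="8", field="image")
edges = [
    Edge(Conn("5", "image"), Conn("8", "image")),      # same node AND field as the new link: replaced
    Edge(Conn("5", "latents"), Conn("8", "latents")),   # different field: must be kept
]
# Keep an edge unless it collides on BOTH the destination node and the destination field.
edges = [e for e in edges if e.destination.node_id != new_dest.node_id or e.destination.field != new_dest.field]
assert len(edges) == 1 and edges[0].destination.field == "latents"
```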
@@ -1,3 +0,0 @@
class CanceledException(Exception):
    """Execution canceled by user."""
    pass
@@ -1,38 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.datatypes.metadata import ImageMetadata
|
||||
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_type: ImageType = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"image_type",
|
||||
"image_name",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class ImageResponse(BaseModel):
|
||||
"""The response type for images"""
|
||||
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
image_url: str = Field(description="The url of the image")
|
||||
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
||||
metadata: ImageMetadata = Field(description="The image's metadata")
|
||||
@@ -1,11 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field

class ImageMetadata(BaseModel):
    """An image's metadata"""

    timestamp: int = Field(description="The creation timestamp of the image")
    width: int = Field(description="The width of the image in pixels")
    height: int = Field(description="The height of the image in pixels")
    # TODO: figure out metadata
    sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")
invokeai/app/invocations/collections.py (new file)

@@ -0,0 +1,50 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal

import cv2 as cv
import numpy as np
import numpy.random
from PIL import Image, ImageOps
from pydantic import Field

from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
from .image import ImageField, ImageOutput


class IntCollectionOutput(BaseInvocationOutput):
    """A collection of integers"""

    type: Literal["int_collection"] = "int_collection"

    # Outputs
    collection: list[int] = Field(default=[], description="The int collection")


class RangeInvocation(BaseInvocation):
    """Creates a range"""

    type: Literal["range"] = "range"

    # Inputs
    start: int = Field(default=0, description="The start of the range")
    stop: int = Field(default=10, description="The stop of the range")
    step: int = Field(default=1, description="The step of the range")

    def invoke(self, context: InvocationContext) -> IntCollectionOutput:
        return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))


class RandomRangeInvocation(BaseInvocation):
    """Creates a collection of random numbers"""

    type: Literal["random_range"] = "random_range"

    # Inputs
    low: int = Field(default=0, description="The inclusive low value")
    high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
    size: int = Field(default=1, description="The number of values to generate")

    def invoke(self, context: InvocationContext) -> IntCollectionOutput:
        return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))
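Note the asymmetry in the two collection nodes: `range()` excludes `stop`, and `numpy.random.randint` treats `low` as inclusive and `high` as exclusive, exactly as the field descriptions say. A two-line check:

```python
import numpy.random

assert list(range(0, 10, 1)) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]            # stop is excluded
assert all(0 <= v < 10 for v in numpy.random.randint(0, 10, size=100))    # low inclusive, high exclusive
```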
@@ -7,9 +7,9 @@ import numpy
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.datatypes.image import ImageField, ImageType
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageOutput
|
||||
from .image import ImageField, ImageOutput
|
||||
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation):
|
||||
|
||||
@@ -8,13 +8,12 @@ from torch import Tensor
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.datatypes.image import ImageField, ImageType
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageOutput
|
||||
from .image import ImageField, ImageOutput
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..datatypes.exceptions import CanceledException
|
||||
from ..util.step_callback import diffusers_step_callback_adapter
|
||||
from ..util.util import diffusers_step_callback_adapter, CanceledException
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(InvokeAIGenerator.schedulers())
|
||||
|
||||
@@ -7,10 +7,20 @@ import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.datatypes.image import ImageField, ImageType
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_type: str = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
#fmt: off
|
||||
|
||||
invokeai/app/invocations/latent.py (new file)

@@ -0,0 +1,321 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from torch import Tensor
|
||||
import torch
|
||||
|
||||
from ...backend.model_management.model_manager import ModelManager
|
||||
from ...backend.util.devices import CUDA_DEVICE, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
import numpy as np
|
||||
from accelerate.utils import set_seed
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from ...backend.generator import Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
import diffusers
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents field used for passing latents between invocations"""
|
||||
|
||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
||||
|
||||
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output latents"""
|
||||
#fmt: off
|
||||
type: Literal["latent_output"] = "latent_output"
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
#fmt: on
|
||||
|
||||
class NoiseOutput(BaseInvocationOutput):
|
||||
"""Invocation noise output"""
|
||||
#fmt: off
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
noise: LatentsField = Field(default=None, description="The output noise")
|
||||
#fmt: on
|
||||
|
||||
|
||||
# TODO: this seems like a hack
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(list(scheduler_map.keys()))
|
||||
]
|
||||
|
||||
|
||||
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
|
||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
return scheduler
|
||||
|
||||
|
||||
def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
|
||||
# limit noise to only the diffusion image channels, not the mask channels
|
||||
input_channels = min(latent_channels, 4)
|
||||
use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
|
||||
generator = torch.Generator(device=use_device).manual_seed(seed)
|
||||
x = torch.randn(
|
||||
[
|
||||
1,
|
||||
input_channels,
|
||||
height // downsampling_factor,
|
||||
width // downsampling_factor,
|
||||
],
|
||||
dtype=torch_dtype(device),
|
||||
device=use_device,
|
||||
generator=generator,
|
||||
).to(device)
|
||||
# if self.perlin > 0.0:
|
||||
# perlin_noise = self.get_perlin_noise(
|
||||
# width // self.downsampling_factor, height // self.downsampling_factor
|
||||
# )
|
||||
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
|
||||
return x
|
||||
|
||||
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
type: Literal["noise"] = "noise"
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||
|
||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||
device = torch.device(CUDA_DEVICE)
|
||||
noise = get_noise(self.width, self.height, device, self.seed)
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, noise)
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=name)
|
||||
)
|
||||
|
||||
|
||||
# Text to image
|
||||
class TextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from a prompt."""
|
||||
|
||||
type: Literal["t2l"] = "t2l"
|
||||
|
||||
# Inputs
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
# fmt: on
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, sample: Tensor, step: int
|
||||
) -> None:
|
||||
# TODO: only output a preview image when requested
|
||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
context.graph_execution_state_id,
|
||||
self.id,
|
||||
{
|
||||
"width": width,
|
||||
"height": height,
|
||||
"dataURL": dataURL
|
||||
},
|
||||
step,
|
||||
self.steps,
|
||||
)
|
||||
|
||||
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||
model_info = model_manager.get_model(self.model)
|
||||
model_name = model_info['model_name']
|
||||
model_hash = model_info['hash']
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
model.scheduler = get_scheduler(
|
||||
model=model,
|
||||
scheduler_name=self.sampler_name
|
||||
)
|
||||
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
self.seamless,
|
||||
self.seamless_axes
|
||||
)
|
||||
else:
|
||||
configure_model_padding(model,
|
||||
self.seamless,
|
||||
self.seamless_axes
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
self.cfg_scale,
|
||||
extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0,#threshold,
|
||||
warmup=0.2,#warmup,
|
||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
|
||||
return conditioning_data
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state.latents, state.step)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
"""Generates latents using latents as base image."""
|
||||
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state.latents, state.step)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=model.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
device=model.device,
|
||||
)
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
|
||||
|
||||
# Latent to image
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
type: Literal["l2i"] = "l2i"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = context.services.model_manager.get_model(self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
|
||||
with torch.inference_mode():
|
||||
np_image = model.decode_latents(latents)
|
||||
image = model.numpy_to_pil(np_image)[0]
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
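To make the seeding behaviour of `get_noise` above concrete, here is a minimal reproduction of its core: a seeded `torch.Generator` makes the latent noise deterministic for a given seed, and the shape arithmetic mirrors the 8× downsampling factor. The helper name is invented; device handling and MPS special-casing from the original are omitted.

```python
import torch


def seeded_noise(width: int, height: int, seed: int, downsampling_factor: int = 8) -> torch.Tensor:
    generator = torch.Generator(device="cpu").manual_seed(seed)
    return torch.randn(
        [1, 4, height // downsampling_factor, width // downsampling_factor],
        generator=generator,
        device="cpu",
    )


a = seeded_noise(512, 512, seed=42)
b = seeded_noise(512, 512, seed=42)
assert a.shape == (1, 4, 64, 64)
assert torch.equal(a, b)  # same seed, identical noise
```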
invokeai/app/invocations/math.py (new file)

@@ -0,0 +1,68 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
|
||||
|
||||
class IntOutput(BaseInvocationOutput):
|
||||
"""An integer output"""
|
||||
#fmt: off
|
||||
type: Literal["int_output"] = "int_output"
|
||||
a: int = Field(default=None, description="The output integer")
|
||||
#fmt: on
|
||||
|
||||
|
||||
class AddInvocation(BaseInvocation):
|
||||
"""Adds two numbers"""
|
||||
#fmt: off
|
||||
type: Literal["add"] = "add"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a + self.b)
|
||||
|
||||
|
||||
class SubtractInvocation(BaseInvocation):
|
||||
"""Subtracts two numbers"""
|
||||
#fmt: off
|
||||
type: Literal["sub"] = "sub"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a - self.b)
|
||||
|
||||
|
||||
class MultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two numbers"""
|
||||
#fmt: off
|
||||
type: Literal["mul"] = "mul"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a * self.b)
|
||||
|
||||
|
||||
class DivideInvocation(BaseInvocation):
|
||||
"""Divides two numbers"""
|
||||
#fmt: off
|
||||
type: Literal["div"] = "div"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=int(self.a / self.b))
|
||||
@@ -3,10 +3,10 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.datatypes.image import ImageField, ImageType
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageOutput
|
||||
from .image import ImageField, ImageOutput
|
||||
|
||||
class RestoreFaceInvocation(BaseInvocation):
|
||||
"""Restores faces in an image."""
|
||||
|
||||
@@ -5,10 +5,10 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.datatypes.image import ImageField, ImageType
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageOutput
|
||||
from .image import ImageField, ImageOutput
|
||||
|
||||
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
|
||||
@@ -1069,9 +1069,8 @@ class GraphExecutionState(BaseModel):
|
||||
n
|
||||
for n in prepared_nodes
|
||||
if all(
|
||||
pit
|
||||
nx.has_path(execution_graph, pit[0], n)
|
||||
for pit in parent_iterators
|
||||
if nx.has_path(execution_graph, pit[0], n)
|
||||
)
|
||||
),
|
||||
None,
|
||||
|
||||
@@ -2,24 +2,24 @@
|
||||
|
||||
import datetime
|
||||
import os
|
||||
from glob import glob
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Callable, Dict, List
|
||||
from typing import Dict
|
||||
|
||||
from PIL.Image import Image
|
||||
import PIL.Image as PILImage
|
||||
from pydantic import BaseModel
|
||||
from invokeai.app.datatypes.image import ImageField, ImageResponse, ImageType
|
||||
from invokeai.app.datatypes.metadata import ImageMetadata
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.util.save_thumbnail import save_thumbnail
|
||||
|
||||
from invokeai.backend.image_util import PngWriter
|
||||
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving images."""
|
||||
|
||||
@@ -27,17 +27,9 @@ class ImageStorageBase(ABC):
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -79,75 +71,19 @@ class DiskImageStorage(ImageStorageBase):
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
dir_path = os.path.join(self.__output_folder, image_type)
|
||||
image_paths = glob(f"{dir_path}/*.png")
|
||||
count = len(image_paths)
|
||||
|
||||
# TODO: do all platforms support `getmtime`? seem to recall some do not...
|
||||
sorted_image_paths = sorted(
|
||||
glob(f"{dir_path}/*.png"), key=os.path.getmtime, reverse=True
|
||||
)
|
||||
|
||||
page_of_image_paths = sorted_image_paths[
|
||||
page * per_page : (page + 1) * per_page
|
||||
]
|
||||
|
||||
page_of_images: List[ImageResponse] = []
|
||||
|
||||
for path in page_of_image_paths:
|
||||
filename = os.path.basename(path)
|
||||
img = PILImage.open(path)
|
||||
page_of_images.append(
|
||||
ImageResponse(
|
||||
image_type=image_type.value,
|
||||
image_name=os.path.basename(path),
|
||||
# TODO: DiskImageStorage should not be building URLs...?
|
||||
image_url=f"api/v1/images/{image_type.value}/{filename}",
|
||||
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
|
||||
metadata=ImageMetadata(
|
||||
timestamp=int(os.path.splitext(filename)[0].split("_")[-1]),
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
page_count_trunc = int(count / per_page)
|
||||
page_count_mod = count % per_page
|
||||
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
|
||||
|
||||
return PaginatedResults[ImageResponse](
|
||||
items=page_of_images,
|
||||
page=page,
|
||||
pages=page_count,
|
||||
per_page=per_page,
|
||||
total=count,
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = PILImage.open(image_path)
|
||||
image = Image.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
if is_thumbnail:
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_type, "thumbnails", image_name
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
return path
|
||||
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||
@@ -165,19 +101,12 @@ class DiskImageStorage(ImageStorageBase):
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
thumbnail_path = self.get_path(image_type, image_name, True)
|
||||
if os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
os.remove(thumbnail_path)
|
||||
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
|
||||
def __get_cache(self, image_name: str) -> Image:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from invokeai.backend import ModelManager
|
||||
|
||||
from .events import EventServiceBase
|
||||
from .latent_storage import LatentsStorageBase
|
||||
from .image_storage import ImageStorageBase
|
||||
from .restoration_services import RestorationServices
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
@@ -11,6 +12,7 @@ class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
events: EventServiceBase
|
||||
latents: LatentsStorageBase
|
||||
images: ImageStorageBase
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
@@ -24,6 +26,7 @@ class InvocationServices:
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
events: EventServiceBase,
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageStorageBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
@@ -32,6 +35,7 @@ class InvocationServices:
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.queue = queue
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
|
||||
@@ -33,7 +33,6 @@ class Invoker:
|
||||
self.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Queue the invocation
|
||||
print(f"queueing item {invocation.id}")
|
||||
self.services.queue.put(
|
||||
InvocationQueueItem(
|
||||
# session_id = session.id,
|
||||
|
||||
invokeai/app/services/latent_storage.py (new file)

@@ -0,0 +1,93 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, name: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||
|
||||
__cache: Dict[str, torch.Tensor]
|
||||
__cache_ids: Queue
|
||||
__max_cache_size: int
|
||||
__underlying_storage: LatentsStorageBase
|
||||
|
||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||
self.__underlying_storage = underlying_storage
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = max_cache_size
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
cache_item = self.__get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self.__underlying_storage.get(name)
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.set(name, data)
|
||||
self.__set_cache(name, data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self.__underlying_storage.delete(name)
|
||||
if name in self.__cache:
|
||||
del self.__cache[name]
|
||||
|
||||
def __get_cache(self, name: str) -> torch.Tensor|None:
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
def __set_cache(self, name: str, data: torch.Tensor):
|
||||
if not name in self.__cache:
|
||||
self.__cache[name] = data
|
||||
self.__cache_ids.put(name)
|
||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||
self.__cache.pop(self.__cache_ids.get())
|
||||
|
||||
|
||||
class DiskLatentsStorage(LatentsStorageBase):
|
||||
"""Stores latents in a folder on disk without caching"""
|
||||
|
||||
__output_folder: str
|
||||
|
||||
def __init__(self, output_folder: str):
|
||||
self.__output_folder = output_folder
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
latent_path = self.get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
os.remove(latent_path)
|
||||
|
||||
def get_path(self, name: str) -> str:
|
||||
return os.path.join(self.__output_folder, name)
|
||||
|
||||
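A brief usage sketch of composing the two storage classes defined in this file: disk persistence wrapped by the in-memory write-through cache. The folder path and latent name are arbitrary, and the import path assumes the repository layout shown elsewhere in this diff.

```python
import torch

from invokeai.app.services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage

latents = ForwardCacheLatentsStorage(DiskLatentsStorage("outputs/latents"), max_cache_size=20)

name = "session123__node7"                     # convention elsewhere: f"{graph_execution_state_id}__{node_id}"
latents.set(name, torch.randn(1, 4, 64, 64))   # written to disk and cached
again = latents.get(name)                      # served from the cache; falls back to torch.load on a cold read
assert again.shape == (1, 4, 64, 64)
latents.delete(name)
```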
@@ -4,7 +4,7 @@ from threading import Event, Thread
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from ..datatypes.exceptions import CanceledException
from ..util.util import CanceledException


class DefaultInvocationProcessor(InvocationProcessorABC):
    __invoker_thread: Thread
@@ -106,12 +106,10 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
        finally:
            self._lock.release()

        page_count_trunc = int(count / per_page)
        page_count_mod = count % per_page
        page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
        pageCount = int(count / per_page) + 1

        return PaginatedResults[T](
            items=items, page=page, pages=page_count, per_page=per_page, total=count
            items=items, page=page, pages=pageCount, per_page=per_page, total=count
        )

    def search(
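As an aside, the two page-count computations in this hunk only disagree when `count` is an exact multiple of `per_page`; a small illustrative sketch (function name hypothetical):

```python
import math

def pages_ceil(count: int, per_page: int) -> int:
    # Ceiling division, equivalent to the truncate-then-adjust form above.
    return math.ceil(count / per_page)

# With count=20 and per_page=10:
print(pages_ceil(20, 10))   # 2
print(int(20 / 10) + 1)     # 3 -- the "+ 1" form reports a trailing empty page
```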
@@ -1,15 +0,0 @@
# Generate the OpenAPI schema json

import json
from invokeai.app.api_app import app
from fastapi.openapi.utils import get_openapi

openapi_doc = get_openapi(
    title=app.title,
    version=app.version,
    openapi_version=app.openapi_version,
    routes=app.routes,
)

with open("./openapi.json", "w") as f:
    json.dump(openapi_doc, f)
@@ -1,16 +1,14 @@
import torch
from PIL import Image
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState


def fast_latents_step_callback(
    sample: torch.Tensor,
    step: int,
    steps: int,
    id: str,
    context: InvocationContext,
):
class CanceledException(Exception):
    pass


def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
    # TODO: only output a preview image when requested
    image = Generator.sample_to_lowres_estimated_image(sample)

@@ -23,12 +21,15 @@ def fast_latents_step_callback(
    context.services.events.emit_generator_progress(
        context.graph_execution_state_id,
        id,
        {"width": width, "height": height, "dataURL": dataURL},
        {
            "width": width,
            "height": height,
            "dataURL": dataURL
        },
        step,
        steps,
    )


def diffusers_step_callback_adapter(*cb_args, **kwargs):
    """
    txt2img gives us a Tensor in the step_callback, while img2img gives us a PipelineIntermediateState.
@@ -36,8 +37,6 @@ def diffusers_step_callback_adapter(*cb_args, **kwargs):
    """
    if isinstance(cb_args[0], PipelineIntermediateState):
        progress_state: PipelineIntermediateState = cb_args[0]
        return fast_latents_step_callback(
            progress_state.latents, progress_state.step, **kwargs
        )
        return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
    else:
        return fast_latents_step_callback(*cb_args, **kwargs)
@@ -531,7 +531,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        run_id: str = None,
        additional_guidance: List[Callable] = None,
    ):
        self._adjust_memory_efficient_attention(latents)
        # FIXME: do we still use any slicing now that PyTorch 2.0 has scaled dot-product attention on all platforms?
        # self._adjust_memory_efficient_attention(latents)
        if run_id is None:
            run_id = secrets.token_urlsafe(self.ID_LENGTH)
        if additional_guidance is None:
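For reference, the FIXME above points at PyTorch 2.0's built-in scaled dot-product attention, which removes much of the need for manual attention slicing; a minimal sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

# Toy attention inputs shaped (batch, heads, tokens, head_dim).
q = torch.randn(1, 8, 64, 40)
k = torch.randn(1, 8, 64, 40)
v = torch.randn(1, 8, 64, 40)

# PyTorch 2.0+ picks an efficient kernel (flash / memory-efficient / math)
# automatically, so explicit slicing or a separate xformers path is often unnecessary.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 64, 40])
```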
@@ -6,5 +6,3 @@ stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*

@@ -3,8 +3,4 @@ dist/
node_modules/
patches/
stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*
@@ -1,16 +1,10 @@
# InvokeAI Web UI

- [InvokeAI Web UI](#invokeai-web-ui)
  - [Stack](#stack)
  - [Contributing](#contributing)
    - [Dev Environment](#dev-environment)
    - [Production builds](#production-builds)

The UI is a fairly straightforward TypeScript React app. The only really fancy part is the Unified Canvas.

The code lives in `invokeai/frontend/web/` if you want to have a look.

## Stack
## Details

State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with the server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).

@@ -38,7 +32,7 @@ Start everything in dev mode:

1. Start the dev server: `yarn dev`
2. Start the InvokeAI UI as usual: `invokeai --web`
3. Point your browser to the dev server address, e.g. <http://localhost:5173/>
3. Point your browser to the dev server address, e.g. `http://localhost:5173/`

### Production builds
@@ -1,87 +0,0 @@
# Generated axios API client

- [Generated axios API client](#generated-axios-api-client)
  - [Generation](#generation)
    - [Generate the API client from the nodes web server](#generate-the-api-client-from-the-nodes-web-server)
    - [Generate the API client from JSON](#generate-the-api-client-from-json)
      - [Getting the JSON from the nodes web server](#getting-the-json-from-the-nodes-web-server)
      - [Getting the JSON with a python script](#getting-the-json-with-a-python-script)
      - [Generate the API client](#generate-the-api-client)
  - [The generated client](#the-generated-client)
  - [API client customisation](#api-client-customisation)

This API client is generated by an [openapi code generator](https://github.com/ferdikoomen/openapi-typescript-codegen).

All files in `invokeai/frontend/web/src/services/api/` are made by the generator.

## Generation

The axios client may be generated from the OpenAPI schema served by the nodes web server, or from a JSON file.

### Generate the API client from the nodes web server

We need to start the nodes web server, which serves the OpenAPI schema to the generator.

1. Start the nodes web server.

   ```bash
   # from the repo root
   python scripts/invoke-new.py --web
   ```

2. Generate the API client.

   ```bash
   # from invokeai/frontend/web/
   yarn api:web
   ```

### Generate the API client from JSON

The JSON can be acquired from the nodes web server, or with a python script.

#### Getting the JSON from the nodes web server

Start the nodes web server as described above, then download the file.

```bash
# from invokeai/frontend/web/
curl http://localhost:9090/openapi.json -o openapi.json
```

#### Getting the JSON with a python script

Run this python script from the repo root, so it can access the nodes server modules.

The script will output `openapi.json` in the repo root. Then we need to move it to `invokeai/frontend/web/`.

```bash
# from the repo root
python invokeai/app/util/generate_openapi_json.py
mv invokeai/app/util/openapi.json invokeai/frontend/web/services/fixtures/
```

#### Generate the API client

Now we can generate the API client from the JSON.

```bash
# from invokeai/frontend/web/
yarn api:file
```

## The generated client

The client will be written to `invokeai/frontend/web/services/api/`:

- `axios` client
- TS types
- An easily parseable schema, which we can use to generate UI

## API client customisation

The generator has a default `request.ts` file that implements a base `axios` client. The generated client uses this base client.

One shortcoming of this base client is that it does not provide response headers unless the response body is empty. To fix this, we provide our own lightly-patched `request.ts`.

To access the headers, call `getHeaders(response)` on any response from the generated api client. This function is exported from `invokeai/frontend/web/src/services/util/getHeaders.ts`.
@@ -1,21 +0,0 @@
# Events

Events via `socket.io`

## `actions.ts`

Redux actions for all socket events. Payloads all include a timestamp, and optionally some other data.

Any reducer (or middleware) can respond to the actions.

## `middleware.ts`

Redux middleware for events.

Handles dispatching the event actions. Only put logic here if it can't really go anywhere else.

For example, on connect we want to load images to the gallery if it's not populated. This requires dispatching a thunk, so we need to directly dispatch this in the middleware.

## `types.ts`

Hand-written types for the socket events. Cannot generate these from the server, but fortunately they are few and simple.
@@ -1,29 +0,0 @@
# Package Scripts

WIP walkthrough of `package.json` scripts.

## `theme` & `theme:watch`

These run the Chakra CLI to generate types for the theme, or watch for code changes and re-generate the types.

The CLI essentially monkeypatches Chakra's files in `node_modules`.

## `postinstall`

The `postinstall` script patches a few packages and runs the Chakra CLI to generate types for the theme.

### Patch `@chakra-ui/cli`

See: <https://github.com/chakra-ui/chakra-ui/issues/7394>

### Patch `redux-persist`

We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.

`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.

So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.

### Patch `redux-deep-persist`

This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.
20
invokeai/frontend/web/index.d.ts
vendored
@@ -1,7 +1,6 @@
import React, { PropsWithChildren } from 'react';
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
import { InvokeTabName } from 'features/ui/store/tabMap';

export {};

@@ -65,24 +64,9 @@ declare module '@invoke-ai/invoke-ai-ui' {
  declare class SettingsModal extends React.Component<SettingsModalProps> {
    public constructor(props: SettingsModalProps);
  }

  declare class StatusIndicator extends React.Component<StatusIndicatorProps> {
    public constructor(props: StatusIndicatorProps);
  }

  declare class ModelSelect extends React.Component<ModelSelectProps> {
    public constructor(props: ModelSelectProps);
  }
}

interface InvokeProps extends PropsWithChildren {
  apiUrl?: string;
  disabledPanels?: string[];
  disabledTabs?: InvokeTabName[];
  token?: string;
}

declare function Invoke(props: InvokeProps): JSX.Element;
declare function Invoke(props: PropsWithChildren): JSX.Element;

export {
  ThemeChanger,
@@ -90,7 +74,5 @@ export {
  IAIPopover,
  IAIIconButton,
  SettingsModal,
  StatusIndicator,
  ModelSelect,
};
export = Invoke;
@@ -5,10 +5,7 @@
  "scripts": {
    "prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
    "dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
    "dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
    "build": "yarn run lint && vite build",
    "api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
    "api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
    "preview": "vite preview",
    "lint:madge": "madge --circular src/main.tsx",
    "lint:eslint": "eslint --max-warnings=0 .",
@@ -46,7 +43,7 @@
    "@chakra-ui/theme-tools": "^2.0.16",
    "@emotion/react": "^11.10.6",
    "@emotion/styled": "^11.10.6",
    "@reduxjs/toolkit": "^1.9.3",
    "@reduxjs/toolkit": "^1.9.2",
    "chakra-ui-contextmenu": "^1.0.5",
    "dateformat": "^5.0.3",
    "formik": "^2.2.9",
@@ -86,7 +83,6 @@
    "@typescript-eslint/eslint-plugin": "^5.52.0",
    "@typescript-eslint/parser": "^5.52.0",
    "@vitejs/plugin-react-swc": "^3.2.0",
    "axios": "^1.3.4",
    "babel-plugin-transform-imports": "^2.0.0",
    "concurrently": "^7.6.0",
    "eslint": "^8.34.0",
@@ -94,16 +90,13 @@
    "eslint-plugin-prettier": "^4.2.1",
    "eslint-plugin-react": "^7.32.2",
    "eslint-plugin-react-hooks": "^4.6.0",
    "form-data": "^4.0.0",
    "husky": "^8.0.3",
    "lint-staged": "^13.1.2",
    "madge": "^6.0.0",
    "openapi-typescript-codegen": "^0.23.0",
    "postinstall-postinstall": "^2.1.0",
    "prettier": "^2.8.4",
    "rollup-plugin-visualizer": "^5.9.0",
    "terser": "^5.16.4",
    "typescript": "4.9.5",
    "vite": "^4.1.2",
    "vite-plugin-eslint": "^1.8.1",
    "vite-tsconfig-paths": "^4.0.5",
@@ -522,9 +522,6 @@
    "resetComplete": "Web UI has been reset. Refresh the page to reload."
  },
  "toast": {
    "serverError": "Server Error",
    "disconnected": "Disconnected from Server",
    "connected": "Connected to Server",
    "tempFoldersEmptied": "Temp Folder Emptied",
    "uploadFailed": "Upload failed",
    "uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",
@@ -13,34 +13,16 @@ import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppDispatch, useAppSelector } from './storeHooks';
import { useAppSelector } from './storeHooks';
import { PropsWithChildren, useEffect } from 'react';
import { setDisabledPanels, setDisabledTabs } from 'features/ui/store/uiSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';

keepGUIAlive();

interface Props extends PropsWithChildren {
  options: {
    disabledPanels: string[];
    disabledTabs: InvokeTabName[];
  };
}

const App = (props: Props) => {
const App = (props: PropsWithChildren) => {
  useToastWatcher();

  const currentTheme = useAppSelector((state) => state.ui.currentTheme);
  const { setColorMode } = useColorMode();
  const dispatch = useAppDispatch();

  useEffect(() => {
    dispatch(setDisabledPanels(props.options.disabledPanels));
  }, [dispatch, props.options.disabledPanels]);

  useEffect(() => {
    dispatch(setDisabledTabs(props.options.disabledTabs));
  }, [dispatch, props.options.disabledTabs]);

  useEffect(() => {
    setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
21
invokeai/frontend/web/src/app/invokeai.d.ts
vendored
@@ -14,7 +14,6 @@

import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageMetadata, ImageType } from 'services/api';

/**
 * TODO:
@@ -114,7 +113,7 @@ export declare type Metadata = SystemGenerationMetadata & {
};

// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
export declare type _Image = {
export declare type Image = {
  uuid: string;
  url: string;
  thumbnail: string;
@@ -125,23 +124,11 @@ export declare type _Image = {
  category: GalleryCategory;
  isBase64?: boolean;
  dreamPrompt?: 'string';
  name?: string;
};

/**
 * ResultImage
 */
export declare type Image = {
  name: string;
  type: ImageType;
  url: string;
  thumbnail: string;
  metadata: ImageMetadata;
};

// GalleryImages is an array of Image.
export declare type GalleryImages = {
  images: Array<_Image>;
  images: Array<Image>;
};

/**
@@ -288,7 +275,7 @@ export declare type SystemStatusResponse = SystemStatus;

export declare type SystemConfigResponse = SystemConfig;

export declare type ImageResultResponse = Omit<_Image, 'uuid'> & {
export declare type ImageResultResponse = Omit<Image, 'uuid'> & {
  boundingBox?: IRect;
  generationMode: InvokeTabName;
};
@@ -309,7 +296,7 @@ export declare type ErrorResponse = {
};

export declare type GalleryImagesResponse = {
  images: Array<Omit<_Image, 'uuid'>>;
  images: Array<Omit<Image, 'uuid'>>;
  areMoreImagesAvailable: boolean;
  category: GalleryCategory;
};
@@ -13,13 +13,9 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
export const generateImage = createAction<InvokeTabName>(
  'socketio/generateImage'
);
export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI._Image>(
  'socketio/runFacetool'
);
export const deleteImage = createAction<InvokeAI._Image>(
  'socketio/deleteImage'
);
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI.Image>('socketio/runFacetool');
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
export const requestImages = createAction<GalleryCategory>(
  'socketio/requestImages'
);
@@ -91,7 +91,7 @@ const makeSocketIOEmitters = (
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
|
||||
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
const {
|
||||
@@ -119,7 +119,7 @@ const makeSocketIOEmitters = (
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
|
||||
emitRunFacetool: (imageToProcess: InvokeAI.Image) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
const {
|
||||
@@ -150,7 +150,7 @@ const makeSocketIOEmitters = (
|
||||
})
|
||||
);
|
||||
},
|
||||
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
|
||||
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
|
||||
const { url, uuid, category, thumbnail } = imageToDelete;
|
||||
dispatch(removeImage(imageToDelete));
|
||||
socketio.emit('deleteImage', url, thumbnail, uuid, category);
|
||||
|
||||
@@ -34,9 +34,8 @@ import type { RootState } from 'app/store';
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
clearInitialImage,
|
||||
initialImageSelected,
|
||||
setInfillMethod,
|
||||
// setInitialImage,
|
||||
setInitialImage,
|
||||
setMaskPath,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { tabMap } from 'features/ui/store/tabMap';
|
||||
@@ -147,8 +146,7 @@ const makeSocketIOListeners = (
|
||||
const activeTabName = tabMap[activeTab];
|
||||
switch (activeTabName) {
|
||||
case 'img2img': {
|
||||
dispatch(initialImageSelected(newImage.uuid));
|
||||
// dispatch(setInitialImage(newImage));
|
||||
dispatch(setInitialImage(newImage));
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -264,7 +262,7 @@ const makeSocketIOListeners = (
|
||||
*/
|
||||
|
||||
// Generate a UUID for each image
|
||||
const preparedImages = images.map((image): InvokeAI._Image => {
|
||||
const preparedImages = images.map((image): InvokeAI.Image => {
|
||||
return {
|
||||
uuid: uuidv4(),
|
||||
...image,
|
||||
@@ -336,7 +334,7 @@ const makeSocketIOListeners = (
|
||||
|
||||
if (
|
||||
initialImage === url ||
|
||||
(initialImage as InvokeAI._Image)?.url === url
|
||||
(initialImage as InvokeAI.Image)?.url === url
|
||||
) {
|
||||
dispatch(clearInitialImage());
|
||||
}
|
||||
|
||||
@@ -29,8 +29,6 @@ export const socketioMiddleware = () => {
|
||||
path: `${window.location.pathname}socket.io`,
|
||||
});
|
||||
|
||||
socketio.disconnect();
|
||||
|
||||
let areListenersSet = false;
|
||||
|
||||
const middleware: Middleware = (store) => (next) => (action) => {
|
||||
|
||||
@@ -7,8 +7,6 @@ import { getPersistConfig } from 'redux-deep-persist';
|
||||
|
||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||
import resultsReducer from 'features/gallery/store/resultsSlice';
|
||||
import uploadsReducer from 'features/gallery/store/uploadsSlice';
|
||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||
import generationReducer from 'features/parameters/store/generationSlice';
|
||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||
@@ -16,7 +14,6 @@ import systemReducer from 'features/system/store/systemSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
|
||||
import { socketioMiddleware } from './socketio/middleware';
|
||||
import { socketMiddleware } from 'services/events/middleware';
|
||||
|
||||
/**
|
||||
* redux-persist provides an easy and reliable way to persist state across reloads.
|
||||
@@ -67,10 +64,6 @@ const lightboxBlacklist = ['isLightboxOpen'].map(
|
||||
(blacklistItem) => `lightbox.${blacklistItem}`
|
||||
);
|
||||
|
||||
const apiBlacklist = ['sessionId', 'status', 'progress', 'progressImage'].map(
|
||||
(blacklistItem) => `api.${blacklistItem}`
|
||||
);
|
||||
|
||||
const rootReducer = combineReducers({
|
||||
generation: generationReducer,
|
||||
postprocessing: postprocessingReducer,
|
||||
@@ -79,8 +72,6 @@ const rootReducer = combineReducers({
|
||||
canvas: canvasReducer,
|
||||
ui: uiReducer,
|
||||
lightbox: lightboxReducer,
|
||||
results: resultsReducer,
|
||||
uploads: uploadsReducer,
|
||||
});
|
||||
|
||||
const rootPersistConfig = getPersistConfig({
|
||||
@@ -92,24 +83,12 @@ const rootPersistConfig = getPersistConfig({
|
||||
...systemBlacklist,
|
||||
...galleryBlacklist,
|
||||
...lightboxBlacklist,
|
||||
...apiBlacklist,
|
||||
// for now, never persist the results/uploads slices
|
||||
'results',
|
||||
'uploads',
|
||||
],
|
||||
debounce: 300,
|
||||
});
|
||||
|
||||
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
|
||||
|
||||
// function buildMiddleware() {
|
||||
// if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
|
||||
// return [socketMiddleware()];
|
||||
// } else {
|
||||
// return [socketioMiddleware()];
|
||||
// }
|
||||
// }
|
||||
|
||||
// Continue with store setup
|
||||
export const store = configureStore({
|
||||
reducer: persistedReducer,
|
||||
@@ -117,7 +96,7 @@ export const store = configureStore({
|
||||
getDefaultMiddleware({
|
||||
immutableCheck: false,
|
||||
serializableCheck: false,
|
||||
}).concat(socketMiddleware()),
|
||||
}).concat(socketioMiddleware()),
|
||||
devTools: {
|
||||
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
||||
actionsDenylist: [
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
import { createAsyncThunk } from '@reduxjs/toolkit';
|
||||
import { AppDispatch, RootState } from './store';
|
||||
|
||||
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
|
||||
export const createAppAsyncThunk = createAsyncThunk.withTypes<{
|
||||
state: RootState;
|
||||
dispatch: AppDispatch;
|
||||
}>();
|
||||
@@ -2,6 +2,7 @@ import { Box, useToast } from '@chakra-ui/react';
|
||||
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import useImageUploader from 'common/hooks/useImageUploader';
|
||||
import { uploadImage } from 'features/gallery/store/thunks/uploadImage';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { ResourceKey } from 'i18next';
|
||||
import {
|
||||
@@ -14,7 +15,6 @@ import {
|
||||
} from 'react';
|
||||
import { FileRejection, useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { uploadImage } from 'services/thunks/image';
|
||||
import ImageUploadOverlay from './ImageUploadOverlay';
|
||||
|
||||
type ImageUploaderProps = {
|
||||
@@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
|
||||
const fileAcceptedCallback = useCallback(
|
||||
async (file: File) => {
|
||||
dispatch(uploadImage({ formData: { file } }));
|
||||
dispatch(uploadImage({ imageFile: file }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
@@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(uploadImage({ formData: { file } }));
|
||||
dispatch(uploadImage({ imageFile: file }));
|
||||
};
|
||||
document.addEventListener('paste', pasteImageListener);
|
||||
return () => {
|
||||
|
||||
@@ -14,7 +14,6 @@ const WorkInProgress = (props: WorkInProgressProps) => {
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
bg: 'base.850',
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import { RootState } from 'app/store';
|
||||
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
|
||||
import { Graph } from 'services/api';
|
||||
import { buildImg2ImgNode, buildTxt2ImgNode } from './buildNodes';
|
||||
|
||||
function mapTabToFunction(activeTabName: InvokeTabName) {
|
||||
switch (activeTabName) {
|
||||
case 'txt2img':
|
||||
return buildTxt2ImgNode;
|
||||
|
||||
case 'img2img':
|
||||
return buildImg2ImgNode;
|
||||
|
||||
default:
|
||||
return buildTxt2ImgNode;
|
||||
}
|
||||
}
|
||||
|
||||
export const buildGraph = (state: RootState): Graph => {
|
||||
const { activeTab } = state.ui;
|
||||
const activeTabName = tabMap[activeTab];
|
||||
const nodeId = uuidv4();
|
||||
|
||||
return {
|
||||
nodes: {
|
||||
[nodeId]: {
|
||||
id: nodeId,
|
||||
...mapTabToFunction(activeTabName)(state),
|
||||
},
|
||||
},
|
||||
};
|
||||
};
|
||||
@@ -1,145 +0,0 @@
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
ImageToImageInvocation,
|
||||
RestoreFaceInvocation,
|
||||
TextToImageInvocation,
|
||||
UpscaleInvocation,
|
||||
} from 'services/api';
|
||||
|
||||
import { _Image } from 'app/invokeai';
|
||||
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
|
||||
|
||||
// fe todo fix model type (frontend uses null, backend uses undefined)
|
||||
// fe todo update front end to store to have whole image field (vs just name)
|
||||
// be todo add symmetry fields
|
||||
// be todo variations....
|
||||
|
||||
export function buildTxt2ImgNode(
|
||||
state: RootState
|
||||
): Omit<TextToImageInvocation, 'id'> {
|
||||
const { generation, system } = state;
|
||||
|
||||
const { shouldDisplayInProgressType, model } = system;
|
||||
|
||||
const {
|
||||
prompt,
|
||||
seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfgScale: cfg_scale,
|
||||
sampler,
|
||||
seamless,
|
||||
shouldRandomizeSeed,
|
||||
} = generation;
|
||||
|
||||
// missing fields in TextToImageInvocation: strength, hires_fix
|
||||
return {
|
||||
type: 'txt2img',
|
||||
prompt,
|
||||
seed: shouldRandomizeSeed ? -1 : seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfg_scale,
|
||||
sampler_name: sampler as TextToImageInvocation['sampler_name'],
|
||||
seamless,
|
||||
model,
|
||||
progress_images: shouldDisplayInProgressType === 'full-res',
|
||||
};
|
||||
}
|
||||
|
||||
export function buildImg2ImgNode(
|
||||
state: RootState
|
||||
): Omit<ImageToImageInvocation, 'id'> {
|
||||
const { generation, system } = state;
|
||||
|
||||
const { shouldDisplayInProgressType, model } = system;
|
||||
|
||||
const {
|
||||
prompt,
|
||||
seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfgScale,
|
||||
sampler,
|
||||
seamless,
|
||||
img2imgStrength: strength,
|
||||
shouldFitToWidthHeight: fit,
|
||||
shouldRandomizeSeed,
|
||||
} = generation;
|
||||
|
||||
const initialImage = initialImageSelector(state);
|
||||
|
||||
if (!initialImage) {
|
||||
// TODO: handle this
|
||||
throw 'no initial image';
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'img2img',
|
||||
prompt,
|
||||
seed: shouldRandomizeSeed ? -1 : seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfg_scale: cfgScale,
|
||||
sampler_name: sampler as ImageToImageInvocation['sampler_name'],
|
||||
seamless,
|
||||
model,
|
||||
progress_images: shouldDisplayInProgressType === 'full-res',
|
||||
image: {
|
||||
image_name: initialImage.name,
|
||||
image_type: 'results',
|
||||
},
|
||||
strength,
|
||||
fit,
|
||||
};
|
||||
}
|
||||
|
||||
export function buildFacetoolNode(
|
||||
state: RootState
|
||||
): Omit<RestoreFaceInvocation, 'id'> {
|
||||
const { generation, postprocessing } = state;
|
||||
|
||||
const { initialImage } = generation;
|
||||
|
||||
const { facetoolStrength: strength } = postprocessing;
|
||||
|
||||
// missing fields in RestoreFaceInvocation: type, codeformer_fidelity
|
||||
return {
|
||||
type: 'restore_face',
|
||||
image: {
|
||||
image_name:
|
||||
(typeof initialImage === 'string' ? initialImage : initialImage?.url) ||
|
||||
'',
|
||||
image_type: 'results',
|
||||
},
|
||||
strength,
|
||||
};
|
||||
}
|
||||
|
||||
// is this ESRGAN??
|
||||
export function buildUpscaleNode(
|
||||
state: RootState
|
||||
): Omit<UpscaleInvocation, 'id'> {
|
||||
const { generation, postprocessing } = state;
|
||||
|
||||
const { initialImage } = generation;
|
||||
|
||||
const { upscalingLevel: level, upscalingStrength: strength } = postprocessing;
|
||||
|
||||
// missing fields in UpscaleInvocation: denoise_str
|
||||
return {
|
||||
type: 'upscale',
|
||||
image: {
|
||||
image_name:
|
||||
(typeof initialImage === 'string' ? initialImage : initialImage?.url) ||
|
||||
'',
|
||||
image_type: 'results',
|
||||
},
|
||||
strength,
|
||||
level,
|
||||
};
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
import dateFormat from 'dateformat';
|
||||
|
||||
/**
|
||||
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
|
||||
*/
|
||||
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');
|
||||
@@ -1,10 +1,8 @@
|
||||
import React, { lazy, PropsWithChildren, useEffect, useState } from 'react';
|
||||
import React, { lazy, PropsWithChildren } from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { PersistGate } from 'redux-persist/integration/react';
|
||||
import { store } from './app/store';
|
||||
import { persistor } from './persistor';
|
||||
import { OpenAPI } from 'services/api';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import '@fontsource/inter/100.css';
|
||||
import '@fontsource/inter/200.css';
|
||||
import '@fontsource/inter/300.css';
|
||||
@@ -23,45 +21,18 @@ import './i18n';
|
||||
const App = lazy(() => import('./app/App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
|
||||
|
||||
interface Props extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
disabledPanels?: string[];
|
||||
disabledTabs?: InvokeTabName[];
|
||||
token?: string;
|
||||
}
|
||||
|
||||
export default function Component({
|
||||
apiUrl,
|
||||
disabledPanels = [],
|
||||
disabledTabs = [],
|
||||
token,
|
||||
children,
|
||||
}: Props) {
|
||||
const [ready, setReady] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
console.log('setting OPENAPI.BASE to', apiUrl);
|
||||
if (apiUrl) OpenAPI.BASE = apiUrl;
|
||||
setReady(true);
|
||||
}, [apiUrl]);
|
||||
|
||||
useEffect(() => {
|
||||
if (token) OpenAPI.TOKEN = token;
|
||||
}, [token]);
|
||||
|
||||
export default function Component(props: PropsWithChildren) {
|
||||
return (
|
||||
<React.StrictMode>
|
||||
{ready && (
|
||||
<Provider store={store}>
|
||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||
<React.Suspense fallback={<Loading showText />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App options={{ disabledPanels, disabledTabs }}>{children}</App>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</PersistGate>
|
||||
</Provider>
|
||||
)}
|
||||
<Provider store={store}>
|
||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||
<React.Suspense fallback={<Loading showText />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App>{props.children}</App>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</PersistGate>
|
||||
</Provider>
|
||||
</React.StrictMode>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,8 +5,6 @@ import ThemeChanger from './features/system/components/ThemeChanger';
|
||||
import IAIPopover from './common/components/IAIPopover';
|
||||
import IAIIconButton from './common/components/IAIIconButton';
|
||||
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
|
||||
import StatusIndicator from './features/system/components/StatusIndicator';
|
||||
import ModelSelect from 'features/system/components/ModelSelect';
|
||||
|
||||
export default Component;
|
||||
export {
|
||||
@@ -15,6 +13,4 @@ export {
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
StatusIndicator,
|
||||
ModelSelect,
|
||||
};
|
||||
|
||||
@@ -156,7 +156,7 @@ export const canvasSlice = createSlice({
|
||||
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
|
||||
state.cursorPosition = action.payload;
|
||||
},
|
||||
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
|
||||
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||
const image = action.payload;
|
||||
const { stageDimensions } = state;
|
||||
|
||||
@@ -291,7 +291,7 @@ export const canvasSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
boundingBox: IRect;
|
||||
image: InvokeAI._Image;
|
||||
image: InvokeAI.Image;
|
||||
}>
|
||||
) => {
|
||||
const { boundingBox, image } = action.payload;
|
||||
|
||||
@@ -37,7 +37,7 @@ export type CanvasImage = {
|
||||
y: number;
|
||||
width: number;
|
||||
height: number;
|
||||
image: InvokeAI._Image;
|
||||
image: InvokeAI.Image;
|
||||
};
|
||||
|
||||
export type CanvasMaskLine = {
|
||||
@@ -125,7 +125,7 @@ export interface CanvasState {
|
||||
cursorPosition: Vector2d | null;
|
||||
doesCanvasNeedScaling: boolean;
|
||||
futureLayerStates: CanvasLayerState[];
|
||||
intermediateImage?: InvokeAI._Image;
|
||||
intermediateImage?: InvokeAI.Image;
|
||||
isCanvasInitialized: boolean;
|
||||
isDrawing: boolean;
|
||||
isMaskEnabled: boolean;
|
||||
|
||||
@@ -105,7 +105,7 @@ export const mergeAndUploadCanvas =
|
||||
|
||||
const { url, width, height } = image;
|
||||
|
||||
const newImage: InvokeAI._Image = {
|
||||
const newImage: InvokeAI.Image = {
|
||||
uuid: uuidv4(),
|
||||
category: shouldSaveToGallery ? 'result' : 'user',
|
||||
...image,
|
||||
|
||||
@@ -14,9 +14,8 @@ import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
|
||||
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
|
||||
import {
|
||||
initialImageSelected,
|
||||
setAllParameters,
|
||||
// setInitialImage,
|
||||
setInitialImage,
|
||||
setSeed,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
|
||||
@@ -130,10 +129,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
const handleClickUseAsInitialImage = () => {
|
||||
if (!currentImage) return;
|
||||
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
|
||||
dispatch(initialImageSelected(currentImage.uuid));
|
||||
// dispatch(setInitialImage(currentImage));
|
||||
|
||||
// dispatch(setActiveTab('img2img'));
|
||||
dispatch(setInitialImage(currentImage));
|
||||
dispatch(setActiveTab('img2img'));
|
||||
};
|
||||
|
||||
const handleCopyImage = async () => {
|
||||
|
||||
@@ -4,20 +4,17 @@ import { useAppSelector } from 'app/storeHooks';
|
||||
import { isEqual } from 'lodash';
|
||||
|
||||
import { MdPhoto } from 'react-icons/md';
|
||||
import {
|
||||
gallerySelector,
|
||||
selectedImageSelector,
|
||||
} from '../store/gallerySelectors';
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import CurrentImageButtons from './CurrentImageButtons';
|
||||
import CurrentImagePreview from './CurrentImagePreview';
|
||||
|
||||
export const currentImageDisplaySelector = createSelector(
|
||||
[gallerySelector, selectedImageSelector],
|
||||
(gallery, selectedImage) => {
|
||||
[gallerySelector],
|
||||
(gallery) => {
|
||||
const { currentImage, intermediateImage } = gallery;
|
||||
|
||||
return {
|
||||
hasAnImageToDisplay: selectedImage || intermediateImage,
|
||||
hasAnImageToDisplay: currentImage || intermediateImage,
|
||||
};
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,44 +1,26 @@
|
||||
import { Box, Flex, Image } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash';
|
||||
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
|
||||
|
||||
import { selectedImageSelector } from '../store/gallerySelectors';
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import CurrentImageFallback from './CurrentImageFallback';
|
||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
|
||||
export const imagesSelector = createSelector(
|
||||
[uiSelector, selectedImageSelector, systemSelector],
|
||||
(ui, selectedImage, system) => {
|
||||
[gallerySelector, uiSelector],
|
||||
(gallery: GalleryState, ui) => {
|
||||
const { currentImage, intermediateImage } = gallery;
|
||||
const { shouldShowImageDetails } = ui;
|
||||
const { progressImage } = system;
|
||||
|
||||
// TODO: Clean this up, this is really gross
|
||||
const imageToDisplay = progressImage
|
||||
? {
|
||||
url: progressImage.dataURL,
|
||||
width: progressImage.width,
|
||||
height: progressImage.height,
|
||||
isProgressImage: true,
|
||||
image: progressImage,
|
||||
}
|
||||
: selectedImage
|
||||
? {
|
||||
url: selectedImage.url,
|
||||
width: selectedImage.metadata.width,
|
||||
height: selectedImage.metadata.height,
|
||||
isProgressImage: false,
|
||||
image: selectedImage,
|
||||
}
|
||||
: null;
|
||||
|
||||
return {
|
||||
imageToDisplay: intermediateImage ? intermediateImage : currentImage,
|
||||
isIntermediate: Boolean(intermediateImage),
|
||||
shouldShowImageDetails,
|
||||
imageToDisplay,
|
||||
};
|
||||
},
|
||||
{
|
||||
@@ -49,9 +31,9 @@ export const imagesSelector = createSelector(
|
||||
);
|
||||
|
||||
export default function CurrentImagePreview() {
|
||||
const { shouldShowImageDetails, imageToDisplay } =
|
||||
const { shouldShowImageDetails, imageToDisplay, isIntermediate } =
|
||||
useAppSelector(imagesSelector);
|
||||
console.log(imageToDisplay);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
@@ -67,42 +49,34 @@ export default function CurrentImagePreview() {
|
||||
src={imageToDisplay.url}
|
||||
width={imageToDisplay.width}
|
||||
height={imageToDisplay.height}
|
||||
fallback={
|
||||
!imageToDisplay.isProgressImage ? (
|
||||
<CurrentImageFallback />
|
||||
) : undefined
|
||||
}
|
||||
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
|
||||
sx={{
|
||||
objectFit: 'contain',
|
||||
maxWidth: '100%',
|
||||
maxHeight: '100%',
|
||||
height: 'auto',
|
||||
position: 'absolute',
|
||||
imageRendering: imageToDisplay.isProgressImage
|
||||
? 'pixelated'
|
||||
: 'initial',
|
||||
imageRendering: isIntermediate ? 'pixelated' : 'initial',
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{!shouldShowImageDetails && <NextPrevImageButtons />}
|
||||
{shouldShowImageDetails &&
|
||||
imageToDisplay &&
|
||||
'metadata' in imageToDisplay.image && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: '0',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
borderRadius: 'base',
|
||||
overflow: 'scroll',
|
||||
maxHeight: APP_METADATA_HEIGHT,
|
||||
}}
|
||||
>
|
||||
<ImageMetadataViewer image={imageToDisplay.image} />
|
||||
</Box>
|
||||
)}
|
||||
{shouldShowImageDetails && imageToDisplay && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: '0',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
borderRadius: 'base',
|
||||
overflow: 'scroll',
|
||||
maxHeight: APP_METADATA_HEIGHT,
|
||||
}}
|
||||
>
|
||||
<ImageMetadataViewer image={imageToDisplay} />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ interface DeleteImageModalProps {
|
||||
/**
|
||||
* The image to delete.
|
||||
*/
|
||||
image?: InvokeAI._Image;
|
||||
image?: InvokeAI.Image;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -9,14 +9,11 @@ import {
|
||||
useToast,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { setCurrentImage } from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
imageSelected,
|
||||
setCurrentImage,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
initialImageSelected,
|
||||
setAllImageToImageParameters,
|
||||
setAllParameters,
|
||||
setInitialImage,
|
||||
setSeed,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { DragEvent, memo, useState } from 'react';
|
||||
@@ -43,7 +40,7 @@ interface HoverableImageProps {
|
||||
const memoEqualityCheck = (
|
||||
prev: HoverableImageProps,
|
||||
next: HoverableImageProps
|
||||
) => prev.image.name === next.image.name && prev.isSelected === next.isSelected;
|
||||
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
|
||||
|
||||
/**
|
||||
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||
@@ -58,7 +55,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
shouldUseSingleGalleryColumn,
|
||||
} = useAppSelector(hoverableImageSelector);
|
||||
const { image, isSelected } = props;
|
||||
const { url, thumbnail, name, metadata } = image;
|
||||
const { url, thumbnail, uuid, metadata } = image;
|
||||
|
||||
const [isHovered, setIsHovered] = useState<boolean>(false);
|
||||
|
||||
@@ -72,9 +69,10 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
const handleMouseOut = () => setIsHovered(false);
|
||||
|
||||
const handleUsePrompt = () => {
|
||||
if (image.metadata?.sd_metadata?.prompt) {
|
||||
setBothPrompts(image.metadata?.sd_metadata?.prompt);
|
||||
if (image.metadata?.image?.prompt) {
|
||||
setBothPrompts(image.metadata?.image?.prompt);
|
||||
}
|
||||
|
||||
toast({
|
||||
title: t('toast.promptSet'),
|
||||
status: 'success',
|
||||
@@ -84,8 +82,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleUseSeed = () => {
|
||||
image.metadata.sd_metadata &&
|
||||
dispatch(setSeed(image.metadata.sd_metadata.image.seed));
|
||||
image.metadata && dispatch(setSeed(image.metadata.image.seed));
|
||||
toast({
|
||||
title: t('toast.seedSet'),
|
||||
status: 'success',
|
||||
@@ -95,11 +92,20 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleSendToImageToImage = () => {
|
||||
dispatch(initialImageSelected(image.name));
|
||||
dispatch(setInitialImage(image));
|
||||
if (activeTabName !== 'img2img') {
|
||||
dispatch(setActiveTab('img2img'));
|
||||
}
|
||||
toast({
|
||||
title: t('toast.sentToImageToImage'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
};
|
||||
|
||||
const handleSendToCanvas = () => {
|
||||
// dispatch(setInitialCanvasImage(image));
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
|
||||
dispatch(resizeAndScaleCanvas());
|
||||
|
||||
@@ -116,7 +122,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleUseAllParameters = () => {
|
||||
metadata.sd_metadata && dispatch(setAllParameters(metadata.sd_metadata));
|
||||
metadata && dispatch(setAllParameters(metadata));
|
||||
toast({
|
||||
title: t('toast.parametersSet'),
|
||||
status: 'success',
|
||||
@@ -126,13 +132,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleUseInitialImage = async () => {
|
||||
if (metadata.sd_metadata?.image?.init_image_path) {
|
||||
const response = await fetch(
|
||||
metadata.sd_metadata?.image?.init_image_path
|
||||
);
|
||||
if (metadata?.image?.init_image_path) {
|
||||
const response = await fetch(metadata.image.init_image_path);
|
||||
if (response.ok) {
|
||||
dispatch(setActiveTab('img2img'));
|
||||
dispatch(setAllImageToImageParameters(metadata?.sd_metadata));
|
||||
dispatch(setAllImageToImageParameters(metadata));
|
||||
toast({
|
||||
title: t('toast.initialImageSet'),
|
||||
status: 'success',
|
||||
@@ -151,18 +155,16 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
});
|
||||
};
|
||||
|
||||
const handleSelectImage = () => {
|
||||
dispatch(imageSelected(image.name));
|
||||
};
|
||||
const handleSelectImage = () => dispatch(setCurrentImage(image));
|
||||
|
||||
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
|
||||
// e.dataTransfer.setData('invokeai/imageUuid', uuid);
|
||||
// e.dataTransfer.effectAllowed = 'move';
|
||||
e.dataTransfer.setData('invokeai/imageUuid', uuid);
|
||||
e.dataTransfer.effectAllowed = 'move';
|
||||
};
|
||||
|
||||
const handleLightBox = () => {
|
||||
// dispatch(setCurrentImage(image));
|
||||
// dispatch(setIsLightboxOpen(true));
|
||||
dispatch(setCurrentImage(image));
|
||||
dispatch(setIsLightboxOpen(true));
|
||||
};
|
||||
|
||||
return (
|
||||
@@ -175,30 +177,28 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClickCapture={handleUsePrompt}
|
||||
isDisabled={image?.metadata?.sd_metadata?.prompt === undefined}
|
||||
isDisabled={image?.metadata?.image?.prompt === undefined}
|
||||
>
|
||||
{t('parameters.usePrompt')}
|
||||
</MenuItem>
|
||||
|
||||
<MenuItem
|
||||
onClickCapture={handleUseSeed}
|
||||
isDisabled={image?.metadata?.sd_metadata?.seed === undefined}
|
||||
isDisabled={image?.metadata?.image?.seed === undefined}
|
||||
>
|
||||
{t('parameters.useSeed')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClickCapture={handleUseAllParameters}
|
||||
isDisabled={
|
||||
!['txt2img', 'img2img'].includes(
|
||||
image?.metadata?.sd_metadata?.type
|
||||
)
|
||||
!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)
|
||||
}
|
||||
>
|
||||
{t('parameters.useAll')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClickCapture={handleUseInitialImage}
|
||||
isDisabled={image?.metadata?.sd_metadata?.type !== 'img2img'}
|
||||
isDisabled={image?.metadata?.image?.type !== 'img2img'}
|
||||
>
|
||||
{t('parameters.useInitImg')}
|
||||
</MenuItem>
|
||||
@@ -209,9 +209,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
{t('parameters.sendToUnifiedCanvas')}
|
||||
</MenuItem>
|
||||
<MenuItem data-warning>
|
||||
{/* <DeleteImageModal image={image}>
|
||||
<DeleteImageModal image={image}>
|
||||
<p>{t('parameters.deleteImage')}</p>
|
||||
</DeleteImageModal> */}
|
||||
</DeleteImageModal>
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
)}
|
||||
@@ -219,7 +219,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
{(ref) => (
|
||||
<Box
|
||||
position="relative"
|
||||
key={name}
|
||||
key={uuid}
|
||||
onMouseOver={handleMouseOver}
|
||||
onMouseOut={handleMouseOut}
|
||||
userSelect="none"
|
||||
@@ -290,7 +290,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
insetInlineEnd: 1,
|
||||
}}
|
||||
>
|
||||
{/* <DeleteImageModal image={image}>
|
||||
<DeleteImageModal image={image}>
|
||||
<IAIIconButton
|
||||
aria-label={t('parameters.deleteImage')}
|
||||
icon={<FaTrashAlt />}
|
||||
@@ -298,7 +298,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
fontSize={14}
|
||||
isDisabled={!mayDeleteImage}
|
||||
/>
|
||||
</DeleteImageModal> */}
|
||||
</DeleteImageModal>
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ButtonGroup, Flex, Grid, Icon, Image, Text } from '@chakra-ui/react';
|
||||
import { ButtonGroup, Flex, Grid, Icon, Text } from '@chakra-ui/react';
|
||||
import { requestImages } from 'app/socketio/actions';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
@@ -25,44 +25,9 @@ import HoverableImage from './HoverableImage';
|
||||
|
||||
import Scrollable from 'features/ui/components/common/Scrollable';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import {
|
||||
resultsAdapter,
|
||||
selectResultsAll,
|
||||
selectResultsTotal,
|
||||
} from '../store/resultsSlice';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
} from 'services/thunks/gallery';
|
||||
import { selectUploadsAll, uploadsAdapter } from '../store/uploadsSlice';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
|
||||
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
||||
|
||||
const gallerySelector = createSelector(
|
||||
[
|
||||
(state: RootState) => state.uploads,
|
||||
(state: RootState) => state.results,
|
||||
(state: RootState) => state.gallery,
|
||||
],
|
||||
(uploads, results, gallery) => {
|
||||
const { currentCategory } = gallery;
|
||||
|
||||
return currentCategory === 'result'
|
||||
? {
|
||||
images: resultsAdapter.getSelectors().selectAll(results),
|
||||
isLoading: results.isLoading,
|
||||
areMoreImagesAvailable: results.page < results.pages - 1,
|
||||
}
|
||||
: {
|
||||
images: uploadsAdapter.getSelectors().selectAll(uploads),
|
||||
isLoading: uploads.isLoading,
|
||||
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
const ImageGalleryContent = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
@@ -70,7 +35,7 @@ const ImageGalleryContent = () => {
|
||||
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
|
||||
|
||||
const {
|
||||
// images,
|
||||
images,
|
||||
currentCategory,
|
||||
currentImageUuid,
|
||||
shouldPinGallery,
|
||||
@@ -78,24 +43,12 @@ const ImageGalleryContent = () => {
|
||||
galleryGridTemplateColumns,
|
||||
galleryImageObjectFit,
|
||||
shouldAutoSwitchToNewImages,
|
||||
// areMoreImagesAvailable,
|
||||
areMoreImagesAvailable,
|
||||
shouldUseSingleGalleryColumn,
|
||||
} = useAppSelector(imageGallerySelector);
|
||||
|
||||
const { images, areMoreImagesAvailable, isLoading } =
|
||||
useAppSelector(gallerySelector);
|
||||
|
||||
// const handleClickLoadMore = () => {
|
||||
// dispatch(requestImages(currentCategory));
|
||||
// };
|
||||
const handleClickLoadMore = () => {
|
||||
if (currentCategory === 'result') {
|
||||
dispatch(receivedResultImagesPage());
|
||||
}
|
||||
|
||||
if (currentCategory === 'user') {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
}
|
||||
dispatch(requestImages(currentCategory));
|
||||
};
|
||||
|
||||
const handleChangeGalleryImageMinimumWidth = (v: number) => {
|
||||
@@ -250,11 +203,11 @@ const ImageGalleryContent = () => {
|
||||
style={{ gridTemplateColumns: galleryGridTemplateColumns }}
|
||||
>
|
||||
{images.map((image) => {
|
||||
const { name } = image;
|
||||
const isSelected = currentImageUuid === name;
|
||||
const { uuid } = image;
|
||||
const isSelected = currentImageUuid === uuid;
|
||||
return (
|
||||
<HoverableImage
|
||||
key={name}
|
||||
key={uuid}
|
||||
image={image}
|
||||
isSelected={isSelected}
|
||||
/>
|
||||
@@ -264,7 +217,6 @@ const ImageGalleryContent = () => {
|
||||
<IAIButton
|
||||
onClick={handleClickLoadMore}
|
||||
isDisabled={!areMoreImagesAvailable}
|
||||
isLoading={isLoading}
|
||||
flexShrink={0}
|
||||
>
|
||||
{areMoreImagesAvailable
|
||||
|
||||
@@ -18,7 +18,7 @@ import {
|
||||
setCfgScale,
|
||||
setHeight,
|
||||
setImg2imgStrength,
|
||||
// setInitialImage,
|
||||
setInitialImage,
|
||||
setMaskPath,
|
||||
setPerlin,
|
||||
setSampler,
|
||||
@@ -120,7 +120,7 @@ type ImageMetadataViewerProps = {
|
||||
const memoEqualityCheck = (
|
||||
prev: ImageMetadataViewerProps,
|
||||
next: ImageMetadataViewerProps
|
||||
) => prev.image.name === next.image.name;
|
||||
) => prev.image.uuid === next.image.uuid;
|
||||
|
||||
// TODO: Show more interesting information in this component.
|
||||
|
||||
@@ -137,8 +137,8 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
dispatch(setShouldShowImageDetails(false));
|
||||
});
|
||||
|
||||
const metadata = image?.metadata.sd_metadata || {};
|
||||
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
|
||||
const metadata = image?.metadata?.image || {};
|
||||
const dreamPrompt = image?.dreamPrompt;
|
||||
|
||||
const {
|
||||
cfg_scale,
|
||||
@@ -160,7 +160,6 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
type,
|
||||
variations,
|
||||
width,
|
||||
model_weights,
|
||||
} = metadata;
|
||||
|
||||
const { t } = useTranslation();
|
||||
@@ -194,8 +193,8 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
{Object.keys(metadata).length > 0 ? (
|
||||
<>
|
||||
{type && <MetadataItem label="Generation type" value={type} />}
|
||||
{model_weights && (
|
||||
<MetadataItem label="Model" value={model_weights} />
|
||||
{image.metadata?.model_weights && (
|
||||
<MetadataItem label="Model" value={image.metadata.model_weights} />
|
||||
)}
|
||||
{['esrgan', 'gfpgan'].includes(type) && (
|
||||
<MetadataItem label="Original image" value={orig_path} />
|
||||
@@ -289,14 +288,14 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
onClick={() => dispatch(setHeight(height))}
|
||||
/>
|
||||
)}
|
||||
{/* {init_image_path && (
|
||||
{init_image_path && (
|
||||
<MetadataItem
|
||||
label="Initial image"
|
||||
value={init_image_path}
|
||||
isLink
|
||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
||||
/>
|
||||
)} */}
|
||||
)}
|
||||
{mask_image_path && (
|
||||
<MetadataItem
|
||||
label="Mask image"
|
||||
|
||||
@@ -7,16 +7,6 @@ import {
|
||||
uiSelector,
|
||||
} from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash';
|
||||
import {
|
||||
selectResultsAll,
|
||||
selectResultsById,
|
||||
selectResultsEntities,
|
||||
} from './resultsSlice';
|
||||
import {
|
||||
selectUploadsAll,
|
||||
selectUploadsById,
|
||||
selectUploadsEntities,
|
||||
} from './uploadsSlice';
|
||||
|
||||
export const gallerySelector = (state: RootState) => state.gallery;
|
||||
|
||||
@@ -85,18 +75,3 @@ export const hoverableImageSelector = createSelector(
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export const selectedImageSelector = createSelector(
|
||||
[gallerySelector, selectResultsEntities, selectUploadsEntities],
|
||||
(gallery, allResults, allUploads) => {
|
||||
const selectedImageName = gallery.selectedImageName;
|
||||
|
||||
if (selectedImageName in allResults) {
|
||||
return allResults[selectedImageName];
|
||||
}
|
||||
|
||||
if (selectedImageName in allUploads) {
|
||||
return allUploads[selectedImageName];
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
@@ -1,17 +1,14 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { invocationComplete } from 'services/events/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { clamp } from 'lodash';
import { isImageOutput } from 'services/types/guards';
import { uploadImage } from 'services/thunks/image';

export type GalleryCategory = 'user' | 'result';

export type AddImagesPayload = {
images: Array<InvokeAI._Image>;
images: Array<InvokeAI.Image>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};
@@ -19,33 +16,16 @@ export type AddImagesPayload = {
|
||||
type GalleryImageObjectFitType = 'contain' | 'cover';
|
||||
|
||||
export type Gallery = {
|
||||
images: InvokeAI._Image[];
|
||||
images: InvokeAI.Image[];
|
||||
latest_mtime?: number;
|
||||
earliest_mtime?: number;
|
||||
areMoreImagesAvailable: boolean;
|
||||
};
|
||||
|
||||
export interface GalleryState {
|
||||
/**
|
||||
* The selected image's unique name
|
||||
* Use `selectedImageSelector` to access the image
|
||||
*/
|
||||
selectedImageName: string;
|
||||
/**
|
||||
* The currently selected image
|
||||
* @deprecated See `state.gallery.selectedImageName`
|
||||
*/
|
||||
currentImage?: InvokeAI._Image;
|
||||
/**
|
||||
* The currently selected image's uuid.
|
||||
* @deprecated See `state.gallery.selectedImageName`, use `selectedImageSelector` to access the image
|
||||
*/
|
||||
currentImage?: InvokeAI.Image;
|
||||
currentImageUuid: string;
|
||||
/**
|
||||
* The current progress image
|
||||
* @deprecated See `state.system.progressImage`
|
||||
*/
|
||||
intermediateImage?: InvokeAI._Image & {
|
||||
intermediateImage?: InvokeAI.Image & {
|
||||
boundingBox?: IRect;
|
||||
generationMode?: InvokeTabName;
|
||||
};
|
||||
@@ -62,7 +42,6 @@ export interface GalleryState {
|
||||
}
|
||||
|
||||
const initialState: GalleryState = {
|
||||
selectedImageName: '',
|
||||
currentImageUuid: '',
|
||||
galleryImageMinimumWidth: 64,
|
||||
galleryImageObjectFit: 'cover',
|
||||
@@ -90,10 +69,7 @@ export const gallerySlice = createSlice({
|
||||
name: 'gallery',
|
||||
initialState,
|
||||
reducers: {
|
||||
imageSelected: (state, action: PayloadAction<string>) => {
|
||||
state.selectedImageName = action.payload;
|
||||
},
|
||||
setCurrentImage: (state, action: PayloadAction<InvokeAI._Image>) => {
|
||||
setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||
state.currentImage = action.payload;
|
||||
state.currentImageUuid = action.payload.uuid;
|
||||
},
|
||||
@@ -148,7 +124,7 @@ export const gallerySlice = createSlice({
|
||||
addImage: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
image: InvokeAI._Image;
|
||||
image: InvokeAI.Image;
|
||||
category: GalleryCategory;
|
||||
}>
|
||||
) => {
|
||||
@@ -174,10 +150,7 @@ export const gallerySlice = createSlice({
|
||||
setIntermediateImage: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
InvokeAI._Image & {
|
||||
boundingBox?: IRect;
|
||||
generationMode?: InvokeTabName;
|
||||
}
|
||||
InvokeAI.Image & { boundingBox?: IRect; generationMode?: InvokeTabName }
|
||||
>
|
||||
) => {
|
||||
state.intermediateImage = action.payload;
|
||||
@@ -279,31 +252,9 @@ export const gallerySlice = createSlice({
|
||||
state.shouldUseSingleGalleryColumn = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Invocation Complete
|
||||
*/
|
||||
builder.addCase(invocationComplete, (state, action) => {
|
||||
const { data } = action.payload;
|
||||
if (isImageOutput(data.result)) {
|
||||
state.selectedImageName = data.result.image.image_name;
|
||||
state.intermediateImage = undefined;
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Upload Image - FULFILLED
|
||||
*/
|
||||
builder.addCase(uploadImage.fulfilled, (state, action) => {
|
||||
const location = action.payload;
|
||||
const imageName = location.split('/').pop() || '';
|
||||
state.selectedImageName = imageName;
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
imageSelected,
|
||||
addImage,
|
||||
clearIntermediateImage,
|
||||
removeImage,
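Both this slice and the results slice below gate their `invocationComplete` handling on the `isImageOutput` guard from services/types/guards, which is not part of this diff. A minimal sketch of what such a discriminated-union check could look like; the `'image'` discriminator value is an assumption for illustration, not the guard's verbatim source:

```ts
// Hypothetical sketch only - the real guard lives in services/types/guards.
import type { ImageField } from 'services/api';

type ImageOutputLike = { type: 'image'; image: ImageField };

export const isImageOutput = (
  output: { type: string }
): output is ImageOutputLike => output.type === 'image';
```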
@@ -1,110 +0,0 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { Image } from 'app/invokeai';
|
||||
import { invocationComplete } from 'services/events/actions';
|
||||
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import { deserializeImageField } from 'services/util/deserializeImageField';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
// import { deserializeImageField } from 'services/util/deserializeImageField';
|
||||
|
||||
// use `createEntityAdapter` to create a slice for results images
|
||||
// https://redux-toolkit.js.org/api/createEntityAdapter#overview
|
||||
|
||||
// the "Entity" is InvokeAI.ResultImage, while the "entities" are instances of that type
|
||||
export const resultsAdapter = createEntityAdapter<Image>({
|
||||
// Provide a callback to get a stable, unique identifier for each entity. This defaults to
|
||||
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
|
||||
selectId: (image) => image.name,
|
||||
// Order all images by their time (in descending order)
|
||||
sortComparer: (a, b) => b.metadata.timestamp - a.metadata.timestamp,
|
||||
});
|
||||
|
||||
// This type is intersected with the Entity type to create the shape of the state
|
||||
type AdditionalResultsState = {
|
||||
// these are a bit misleading; they refer to sessions, not results, but we don't have a route
|
||||
// to list all images directly at this time...
|
||||
page: number; // current page we are on
|
||||
pages: number; // the total number of pages available
|
||||
isLoading: boolean; // whether we are loading more images or not, mostly a placeholder
|
||||
nextPage: number; // the next page to request
|
||||
};
|
||||
|
||||
const resultsSlice = createSlice({
|
||||
name: 'results',
|
||||
initialState: resultsAdapter.getInitialState<AdditionalResultsState>({
|
||||
// provide the additional initial state
|
||||
page: 0,
|
||||
pages: 0,
|
||||
isLoading: false,
|
||||
nextPage: 0,
|
||||
}),
|
||||
reducers: {
|
||||
// the adapter provides some helper reducers; see the docs for all of them
|
||||
// can use them as helper functions within a reducer, or use the function itself as a reducer
|
||||
|
||||
// here we just use the function itself as the reducer. we'll call this on `invocation_complete`
|
||||
// to add a single result
|
||||
resultAdded: resultsAdapter.addOne,
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
|
||||
// because we pass in the fulfilled thunk action creator, everything is typed
|
||||
|
||||
/**
|
||||
* Received Result Images Page - PENDING
|
||||
*/
|
||||
builder.addCase(receivedResultImagesPage.pending, (state) => {
|
||||
state.isLoading = true;
|
||||
});
|
||||
|
||||
/**
|
||||
* Received Result Images Page - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||
const { items, page, pages } = action.payload;
|
||||
|
||||
const resultImages = items.map((image) =>
|
||||
deserializeImageResponse(image)
|
||||
);
|
||||
|
||||
// use the adapter reducer to append all the results to state
|
||||
resultsAdapter.addMany(state, resultImages);
|
||||
|
||||
state.page = page;
|
||||
state.pages = pages;
|
||||
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
|
||||
state.isLoading = false;
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation Complete
|
||||
*/
|
||||
builder.addCase(invocationComplete, (state, action) => {
|
||||
const { data } = action.payload;
|
||||
|
||||
if (isImageOutput(data.result)) {
|
||||
const resultImage = deserializeImageField(data.result.image);
|
||||
resultsAdapter.addOne(state, resultImage);
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
// Create a set of memoized selectors based on the location of this entity state
|
||||
// to be used as selectors in a `useAppSelector()` call
|
||||
export const {
|
||||
selectAll: selectResultsAll,
|
||||
selectById: selectResultsById,
|
||||
selectEntities: selectResultsEntities,
|
||||
selectIds: selectResultsIds,
|
||||
selectTotal: selectResultsTotal,
|
||||
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
|
||||
|
||||
export const { resultAdded } = resultsSlice.actions;
|
||||
|
||||
export default resultsSlice.reducer;
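For reference, the entity-adapter selectors exported above are consumed like ordinary memoized selectors. A brief sketch, assuming `app/store` exports the configured `store` (that import is an assumption, not shown in this diff):

```ts
// Illustrative only: reading result entities via the generated selectors.
import { store } from 'app/store';
import { selectResultsAll, selectResultsTotal } from './resultsSlice';

const state = store.getState();
console.log(`loaded ${selectResultsTotal(state)} result images`);
selectResultsAll(state).forEach((image) => {
  // entities are keyed and sorted by the adapter's selectId / sortComparer
  console.log(image.name, image.metadata.timestamp);
});
```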
@@ -0,0 +1,54 @@
import { AnyAction, ThunkAction } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { RootState } from 'app/store';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { setInitialImage } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { v4 as uuidv4 } from 'uuid';
import { addImage } from '../gallerySlice';

type UploadImageConfig = {
imageFile: File;
};

export const uploadImage =
(
config: UploadImageConfig
): ThunkAction<void, RootState, unknown, AnyAction> =>
async (dispatch, getState) => {
const { imageFile } = config;

const state = getState() as RootState;

const activeTabName = activeTabNameSelector(state);

const formData = new FormData();

formData.append('file', imageFile, imageFile.name);
formData.append(
'data',
JSON.stringify({
kind: 'init',
})
);

const response = await fetch(`${window.location.origin}/upload`, {
method: 'POST',
body: formData,
});

const image = (await response.json()) as InvokeAI.ImageUploadResponse;
const newImage: InvokeAI.Image = {
uuid: uuidv4(),
category: 'user',
...image,
};

dispatch(addImage({ image: newImage, category: 'user' }));

if (activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(newImage));
} else if (activeTabName === 'img2img') {
dispatch(setInitialImage(newImage));
}
};
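A minimal sketch of dispatching the thunk above from a file-input change handler; the hook name and import path are illustrative and not part of this diff:

```ts
import type { ChangeEvent } from 'react';
import { useAppDispatch } from 'app/storeHooks';
import { uploadImage } from './thunks/uploadImage'; // illustrative path

export const useImageUploadHandler = () => {
  const dispatch = useAppDispatch();

  return (e: ChangeEvent<HTMLInputElement>) => {
    const imageFile = e.target.files?.[0];
    if (!imageFile) return;

    // The thunk POSTs the multipart body to `${origin}/upload`, adds the
    // uploaded image to the gallery, and seeds img2img or the canvas
    // depending on the active tab.
    dispatch(uploadImage({ imageFile }));
  };
};
```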
@@ -1,86 +0,0 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { Image } from 'app/invokeai';
|
||||
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
receivedUploadImagesPage,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { uploadImage } from 'services/thunks/image';
|
||||
import { deserializeImageField } from 'services/util/deserializeImageField';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
|
||||
export const uploadsAdapter = createEntityAdapter<Image>({
|
||||
selectId: (image) => image.name,
|
||||
sortComparer: (a, b) => b.metadata.timestamp - a.metadata.timestamp,
|
||||
});
|
||||
|
||||
type AdditionalUploadsState = {
|
||||
page: number;
|
||||
pages: number;
|
||||
isLoading: boolean;
|
||||
nextPage: number;
|
||||
};
|
||||
|
||||
const uploadsSlice = createSlice({
|
||||
name: 'uploads',
|
||||
initialState: uploadsAdapter.getInitialState<AdditionalUploadsState>({
|
||||
page: 0,
|
||||
pages: 0,
|
||||
nextPage: 0,
|
||||
isLoading: false,
|
||||
}),
|
||||
reducers: {
|
||||
uploadAdded: uploadsAdapter.addOne,
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
/**
|
||||
* Received Upload Images Page - PENDING
|
||||
*/
|
||||
builder.addCase(receivedUploadImagesPage.pending, (state) => {
|
||||
state.isLoading = true;
|
||||
});
|
||||
|
||||
/**
|
||||
* Received Upload Images Page - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
||||
const { items, page, pages } = action.payload;
|
||||
|
||||
const images = items.map((image) => deserializeImageResponse(image));
|
||||
|
||||
uploadsAdapter.addMany(state, images);
|
||||
|
||||
state.page = page;
|
||||
state.pages = pages;
|
||||
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
|
||||
state.isLoading = false;
|
||||
});
|
||||
|
||||
/**
|
||||
* Upload Image - FULFILLED
|
||||
*/
|
||||
builder.addCase(uploadImage.fulfilled, (state, action) => {
|
||||
const location = action.payload;
|
||||
|
||||
const uploadedImage = deserializeImageField({
|
||||
image_name: location.split('/').pop() || '',
|
||||
image_type: 'uploads',
|
||||
});
|
||||
|
||||
uploadsAdapter.addOne(state, uploadedImage);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectUploadsAll,
|
||||
selectById: selectUploadsById,
|
||||
selectEntities: selectUploadsEntities,
|
||||
selectIds: selectUploadsIds,
|
||||
selectTotal: selectUploadsTotal,
|
||||
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
|
||||
|
||||
export const { uploadAdded } = uploadsSlice.actions;
|
||||
|
||||
export default uploadsSlice.reducer;
|
||||
@@ -3,7 +3,7 @@ import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
|
||||
import * as InvokeAI from 'app/invokeai';
|
||||
|
||||
type ReactPanZoomProps = {
|
||||
image: InvokeAI._Image;
|
||||
image: InvokeAI.Image;
|
||||
styleClass?: string;
|
||||
alt?: string;
|
||||
ref?: React.Ref<HTMLImageElement>;
|
||||
|
||||
@@ -21,10 +21,9 @@ type ParametersAccordionsType = {
|
||||
const ParametersAccordion = (props: ParametersAccordionsType) => {
|
||||
const { accordionInfo } = props;
|
||||
|
||||
const { system, ui } = useAppSelector((state: RootState) => state);
|
||||
|
||||
const { openAccordions } = system;
|
||||
const { disabledParameterPanels } = ui;
|
||||
const openAccordions = useAppSelector(
|
||||
(state: RootState) => state.system.openAccordions
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
@@ -40,19 +39,15 @@ const ParametersAccordion = (props: ParametersAccordionsType) => {
|
||||
Object.keys(accordionInfo).forEach((key) => {
|
||||
const { header, feature, content, additionalHeaderComponents } =
|
||||
accordionInfo[key];
|
||||
|
||||
// do not render if panel is disabled in global state
|
||||
if (disabledParameterPanels.indexOf(key) === -1) {
|
||||
accordionsToRender.push(
|
||||
<InvokeAccordionItem
|
||||
key={key}
|
||||
header={header}
|
||||
feature={feature}
|
||||
content={content}
|
||||
additionalHeaderComponents={additionalHeaderComponents}
|
||||
/>
|
||||
);
|
||||
}
|
||||
accordionsToRender.push(
|
||||
<InvokeAccordionItem
|
||||
key={key}
|
||||
header={header}
|
||||
feature={feature}
|
||||
content={content}
|
||||
additionalHeaderComponents={additionalHeaderComponents}
|
||||
/>
|
||||
);
|
||||
});
|
||||
}
|
||||
return accordionsToRender;
|
||||
|
||||
@@ -11,7 +11,6 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaPlay } from 'react-icons/fa';
|
||||
import { createSession } from 'services/thunks/session';
|
||||
|
||||
interface InvokeButton
|
||||
extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> {
|
||||
@@ -25,8 +24,7 @@ export default function InvokeButton(props: InvokeButton) {
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
|
||||
const handleClickGenerate = () => {
|
||||
// dispatch(generateImage(activeTabName));
|
||||
dispatch(createSession());
|
||||
dispatch(generateImage(activeTabName));
|
||||
};
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
selectResultsById,
|
||||
selectResultsEntities,
|
||||
} from 'features/gallery/store/resultsSlice';
|
||||
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
|
||||
import { isEqual } from 'lodash';
|
||||
|
||||
export const generationSelector = (state: RootState) => state.generation;
|
||||
@@ -21,15 +15,3 @@ export const mayGenerateMultipleImagesSelector = createSelector(
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export const initialImageSelector = createSelector(
|
||||
[(state: RootState) => state, generationSelector],
|
||||
(state, generation) => {
|
||||
const { initialImage: initialImageName } = generation;
|
||||
|
||||
return (
|
||||
selectResultsById(state, initialImageName as string) ??
|
||||
selectUploadsById(state, initialImageName as string)
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
@@ -11,7 +11,7 @@ export interface GenerationState {
|
||||
height: number;
|
||||
img2imgStrength: number;
|
||||
infillMethod: string;
|
||||
initialImage?: InvokeAI._Image | string; // can be an Image or url
|
||||
initialImage?: InvokeAI.Image | string; // can be an Image or url
|
||||
iterations: number;
|
||||
maskPath: string;
|
||||
perlin: number;
|
||||
@@ -317,12 +317,12 @@ export const generationSlice = createSlice({
|
||||
setShouldRandomizeSeed: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldRandomizeSeed = action.payload;
|
||||
},
|
||||
// setInitialImage: (
|
||||
// state,
|
||||
// action: PayloadAction<InvokeAI._Image | string>
|
||||
// ) => {
|
||||
// state.initialImage = action.payload;
|
||||
// },
|
||||
setInitialImage: (
|
||||
state,
|
||||
action: PayloadAction<InvokeAI.Image | string>
|
||||
) => {
|
||||
state.initialImage = action.payload;
|
||||
},
|
||||
clearInitialImage: (state) => {
|
||||
state.initialImage = undefined;
|
||||
},
|
||||
@@ -353,9 +353,6 @@ export const generationSlice = createSlice({
|
||||
setVerticalSymmetrySteps: (state, action: PayloadAction<number>) => {
|
||||
state.verticalSymmetrySteps = action.payload;
|
||||
},
|
||||
initialImageSelected: (state, action: PayloadAction<string>) => {
|
||||
state.initialImage = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -371,7 +368,7 @@ export const {
|
||||
setHeight,
|
||||
setImg2imgStrength,
|
||||
setInfillMethod,
|
||||
// setInitialImage,
|
||||
setInitialImage,
|
||||
setIterations,
|
||||
setMaskPath,
|
||||
setParameter,
|
||||
@@ -397,7 +394,6 @@ export const {
|
||||
setShouldUseSymmetry,
|
||||
setHorizontalSymmetrySteps,
|
||||
setVerticalSymmetrySteps,
|
||||
initialImageSelected,
|
||||
} = generationSlice.actions;
|
||||
|
||||
export default generationSlice.reducer;
|
||||
|
||||
@@ -1,24 +1,9 @@
import { useToast, UseToastOptions } from '@chakra-ui/react';
import { useToast } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { toastQueueSelector } from 'features/system/store/systemSelectors';
import { clearToastQueue } from 'features/system/store/systemSlice';
import { useEffect } from 'react';

export type MakeToastArg = string | UseToastOptions;

export const makeToast = (arg: MakeToastArg): UseToastOptions => {
if (typeof arg === 'string') {
return {
title: arg,
status: 'info',
isClosable: true,
duration: 2500,
};
}

return { status: 'info', isClosable: true, duration: 2500, ...arg };
};

const useToastWatcher = () => {
const dispatch = useAppDispatch();
const toastQueue = useAppSelector(toastQueueSelector);
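The body of `useToastWatcher` is cut off by the hunk boundary above. A sketch of how a queue-draining watcher along these lines could work, using only the imports already declared in the file; this is an assumption, not the file's actual remainder:

```ts
const useToastWatcherSketch = () => {
  const dispatch = useAppDispatch();
  const toastQueue = useAppSelector(toastQueueSelector);
  const toast = useToast();

  useEffect(() => {
    // Show each queued toast once, then clear the queue in the store.
    toastQueue.forEach((t) => toast(t));
    if (toastQueue.length > 0) {
      dispatch(clearToastQueue());
    }
  }, [dispatch, toast, toastQueue]);
};
```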
@@ -2,20 +2,7 @@ import { ExpandedIndex, UseToastOptions } from '@chakra-ui/react';
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import * as InvokeAI from 'app/invokeai';
|
||||
import {
|
||||
generatorProgress,
|
||||
invocationComplete,
|
||||
invocationError,
|
||||
invocationStarted,
|
||||
socketConnected,
|
||||
socketDisconnected,
|
||||
} from 'services/events/actions';
|
||||
|
||||
import i18n from 'i18n';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import { ProgressImage } from 'services/events/types';
|
||||
import { initialImageSelected } from 'features/parameters/store/generationSlice';
|
||||
import { makeToast } from '../hooks/useToastWatcher';
|
||||
|
||||
export type LogLevel = 'info' | 'warning' | 'error';
|
||||
|
||||
@@ -69,10 +56,6 @@ export interface SystemState
|
||||
cancelType: CancelType;
|
||||
cancelAfter: number | null;
|
||||
};
|
||||
/**
|
||||
* The current progress image
|
||||
*/
|
||||
progressImage: ProgressImage | null;
|
||||
}
|
||||
|
||||
const initialSystemState: SystemState = {
|
||||
@@ -115,7 +98,6 @@ const initialSystemState: SystemState = {
|
||||
cancelType: 'immediate',
|
||||
cancelAfter: null,
|
||||
},
|
||||
progressImage: null,
|
||||
};
|
||||
|
||||
export const systemSlice = createSlice({
|
||||
@@ -290,111 +272,6 @@ export const systemSlice = createSlice({
|
||||
state.cancelOptions.cancelAfter = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Socket Connected
|
||||
*/
|
||||
builder.addCase(socketConnected, (state, action) => {
|
||||
const { timestamp } = action.payload;
|
||||
|
||||
state.isConnected = true;
|
||||
state.currentStatus = i18n.t('common.statusConnected');
|
||||
state.log.push({
|
||||
timestamp,
|
||||
message: `Connected to server`,
|
||||
level: 'info',
|
||||
});
|
||||
state.toastQueue.push(
|
||||
makeToast({ title: i18n.t('toast.connected'), status: 'success' })
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Socket Disconnected
|
||||
*/
|
||||
builder.addCase(socketDisconnected, (state, action) => {
|
||||
const { timestamp } = action.payload;
|
||||
|
||||
state.isConnected = false;
|
||||
state.currentStatus = i18n.t('common.statusDisconnected');
|
||||
state.log.push({
|
||||
timestamp,
|
||||
message: `Disconnected from server`,
|
||||
level: 'error',
|
||||
});
|
||||
state.toastQueue.push(
|
||||
makeToast({ title: i18n.t('toast.disconnected'), status: 'error' })
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation Started
|
||||
*/
|
||||
builder.addCase(invocationStarted, (state) => {
|
||||
state.isProcessing = true;
|
||||
state.currentStatusHasSteps = false;
|
||||
});
|
||||
|
||||
/**
|
||||
* Generator Progress
|
||||
*/
|
||||
builder.addCase(generatorProgress, (state, action) => {
|
||||
const { step, total_steps, progress_image } = action.payload.data;
|
||||
|
||||
state.currentStatusHasSteps = true;
|
||||
state.currentStep = step + 1; // TODO: step starts at -1, think this is a bug
|
||||
state.totalSteps = total_steps;
|
||||
state.progressImage = progress_image ?? null;
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation Complete
|
||||
*/
|
||||
builder.addCase(invocationComplete, (state, action) => {
|
||||
const { data, timestamp } = action.payload;
|
||||
|
||||
state.isProcessing = false;
|
||||
state.currentStep = 0;
|
||||
state.totalSteps = 0;
|
||||
state.progressImage = null;
|
||||
|
||||
// TODO: handle logging for other invocation types
|
||||
if (isImageOutput(data.result)) {
|
||||
state.log.push({
|
||||
timestamp,
|
||||
message: `Generated: ${data.result.image.image_name}`,
|
||||
level: 'info',
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation Error
|
||||
*/
|
||||
builder.addCase(invocationError, (state, action) => {
|
||||
const { data, timestamp } = action.payload;
|
||||
|
||||
state.log.push({
|
||||
timestamp,
|
||||
message: `Server error: ${data.error}`,
|
||||
level: 'error',
|
||||
});
|
||||
|
||||
state.wasErrorSeen = true;
|
||||
state.progressImage = null;
|
||||
state.isProcessing = false;
|
||||
state.toastQueue.push(
|
||||
makeToast({ title: i18n.t('toast.serverError'), status: 'error' })
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Initial Image Selected
|
||||
*/
|
||||
builder.addCase(initialImageSelected, (state) => {
|
||||
state.toastQueue.push(makeToast(i18n.t('toast.sentToImageToImage')));
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
|
||||
@@ -45,41 +45,38 @@ const tabIconStyles: ChakraProps['sx'] = {
|
||||
boxSize: 6,
|
||||
};
|
||||
|
||||
const buildTabs = (disabledTabs: InvokeTabName[]): InvokeTabInfo[] => {
|
||||
const tabs: InvokeTabInfo[] = [
|
||||
{
|
||||
id: 'txt2img',
|
||||
icon: <Icon as={MdTextFields} sx={tabIconStyles} />,
|
||||
workarea: <TextToImageWorkarea />,
|
||||
},
|
||||
{
|
||||
id: 'img2img',
|
||||
icon: <Icon as={MdPhotoLibrary} sx={tabIconStyles} />,
|
||||
workarea: <ImageToImageWorkarea />,
|
||||
},
|
||||
{
|
||||
id: 'unifiedCanvas',
|
||||
icon: <Icon as={MdGridOn} sx={tabIconStyles} />,
|
||||
workarea: <UnifiedCanvasWorkarea />,
|
||||
},
|
||||
{
|
||||
id: 'nodes',
|
||||
icon: <Icon as={MdDeviceHub} sx={tabIconStyles} />,
|
||||
workarea: <NodesWIP />,
|
||||
},
|
||||
{
|
||||
id: 'postprocessing',
|
||||
icon: <Icon as={MdPhotoFilter} sx={tabIconStyles} />,
|
||||
workarea: <PostProcessingWIP />,
|
||||
},
|
||||
{
|
||||
id: 'training',
|
||||
icon: <Icon as={MdFlashOn} sx={tabIconStyles} />,
|
||||
workarea: <TrainingWIP />,
|
||||
},
|
||||
];
|
||||
return tabs.filter((tab) => !disabledTabs.includes(tab.id));
|
||||
};
|
||||
const tabInfo: InvokeTabInfo[] = [
|
||||
{
|
||||
id: 'txt2img',
|
||||
icon: <Icon as={MdTextFields} sx={tabIconStyles} />,
|
||||
workarea: <TextToImageWorkarea />,
|
||||
},
|
||||
{
|
||||
id: 'img2img',
|
||||
icon: <Icon as={MdPhotoLibrary} sx={tabIconStyles} />,
|
||||
workarea: <ImageToImageWorkarea />,
|
||||
},
|
||||
{
|
||||
id: 'unifiedCanvas',
|
||||
icon: <Icon as={MdGridOn} sx={tabIconStyles} />,
|
||||
workarea: <UnifiedCanvasWorkarea />,
|
||||
},
|
||||
{
|
||||
id: 'nodes',
|
||||
icon: <Icon as={MdDeviceHub} sx={tabIconStyles} />,
|
||||
workarea: <NodesWIP />,
|
||||
},
|
||||
{
|
||||
id: 'postprocessing',
|
||||
icon: <Icon as={MdPhotoFilter} sx={tabIconStyles} />,
|
||||
workarea: <PostProcessingWIP />,
|
||||
},
|
||||
{
|
||||
id: 'training',
|
||||
icon: <Icon as={MdFlashOn} sx={tabIconStyles} />,
|
||||
workarea: <TrainingWIP />,
|
||||
},
|
||||
];
|
||||
|
||||
export default function InvokeTabs() {
|
||||
const activeTab = useAppSelector(activeTabIndexSelector);
|
||||
@@ -88,10 +85,13 @@ export default function InvokeTabs() {
|
||||
(state: RootState) => state.lightbox.isLightboxOpen
|
||||
);
|
||||
|
||||
const { shouldPinGallery, disabledTabs, shouldPinParametersPanel } =
|
||||
useAppSelector((state: RootState) => state.ui);
|
||||
const shouldPinGallery = useAppSelector(
|
||||
(state: RootState) => state.ui.shouldPinGallery
|
||||
);
|
||||
|
||||
const activeTabs = buildTabs(disabledTabs);
|
||||
const shouldPinParametersPanel = useAppSelector(
|
||||
(state: RootState) => state.ui.shouldPinParametersPanel
|
||||
);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -142,7 +142,7 @@ export default function InvokeTabs() {
|
||||
|
||||
const tabs = useMemo(
|
||||
() =>
|
||||
activeTabs.map((tab) => (
|
||||
tabInfo.map((tab) => (
|
||||
<Tooltip
|
||||
key={tab.id}
|
||||
hasArrow
|
||||
@@ -157,13 +157,13 @@ export default function InvokeTabs() {
|
||||
</Tab>
|
||||
</Tooltip>
|
||||
)),
|
||||
[t, activeTabs]
|
||||
[t]
|
||||
);
|
||||
|
||||
const tabPanels = useMemo(
|
||||
() =>
|
||||
activeTabs.map((tab) => <TabPanel key={tab.id}>{tab.workarea}</TabPanel>),
|
||||
[activeTabs]
|
||||
tabInfo.map((tab) => <TabPanel key={tab.id}>{tab.workarea}</TabPanel>),
|
||||
[]
|
||||
);
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Box, BoxProps, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { initialImageSelected } from 'features/parameters/store/generationSlice';
|
||||
import { setInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
activeTabNameSelector,
|
||||
uiSelector,
|
||||
@@ -47,7 +47,7 @@ const InvokeWorkarea = (props: InvokeWorkareaProps) => {
|
||||
const image = getImageByUuid(uuid);
|
||||
if (!image) return;
|
||||
if (activeTabName === 'img2img') {
|
||||
dispatch(initialImageSelected(image.uuid));
|
||||
dispatch(setInitialImage(image));
|
||||
} else if (activeTabName === 'unifiedCanvas') {
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
}
|
||||
|
||||
@@ -96,6 +96,7 @@ const ParametersPanel = ({ children }: ParametersPanelProps) => {
|
||||
onClose={closeParametersPanel}
|
||||
isPinned={shouldPinParametersPanel || isLightboxOpen}
|
||||
sx={{
|
||||
borderColor: 'base.700',
|
||||
p: shouldPinParametersPanel ? 0 : 4,
|
||||
bg: 'base.900',
|
||||
}}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import { Flex, Image, Text, useToast } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import ImageUploaderIconButton from 'common/components/ImageUploaderIconButton';
|
||||
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export default function InitImagePreview() {
|
||||
const initialImage = useAppSelector(initialImageSelector);
|
||||
const initialImage = useAppSelector(
|
||||
(state: RootState) => state.generation.initialImage
|
||||
);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
import { InvokeTabName, tabMap } from './tabMap';
import { UIState } from './uiTypes';

export const setActiveTabReducer = (
state: UIState,
newActiveTab: number | InvokeTabName
) => {
if (typeof newActiveTab === 'number') {
state.activeTab = newActiveTab;
} else {
state.activeTab = tabMap.indexOf(newActiveTab);
}
};
@@ -1,7 +1,5 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { initialImageSelected } from 'features/parameters/store/generationSlice';
|
||||
import { setActiveTabReducer } from './extraReducers';
|
||||
import { InvokeTabName, tabMap } from './tabMap';
|
||||
import { AddNewModelType, UIState } from './uiTypes';
|
||||
|
||||
@@ -18,8 +16,6 @@ const initialtabsState: UIState = {
|
||||
addNewModelUIOption: null,
|
||||
shouldPinGallery: true,
|
||||
shouldShowGallery: true,
|
||||
disabledParameterPanels: [],
|
||||
disabledTabs: [],
|
||||
};
|
||||
|
||||
const initialState: UIState = initialtabsState;
|
||||
@@ -29,7 +25,11 @@ export const uiSlice = createSlice({
|
||||
initialState,
|
||||
reducers: {
|
||||
setActiveTab: (state, action: PayloadAction<number | InvokeTabName>) => {
|
||||
setActiveTabReducer(state, action.payload);
|
||||
if (typeof action.payload === 'number') {
|
||||
state.activeTab = action.payload;
|
||||
} else {
|
||||
state.activeTab = tabMap.indexOf(action.payload);
|
||||
}
|
||||
},
|
||||
setCurrentTheme: (state, action: PayloadAction<string>) => {
|
||||
state.currentTheme = action.payload;
|
||||
@@ -92,19 +92,6 @@ export const uiSlice = createSlice({
|
||||
state.shouldShowParametersPanel = true;
|
||||
}
|
||||
},
|
||||
setDisabledPanels: (state, action: PayloadAction<string[]>) => {
|
||||
state.disabledParameterPanels = action.payload;
|
||||
},
|
||||
setDisabledTabs: (state, action: PayloadAction<InvokeTabName[]>) => {
|
||||
state.disabledTabs = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(initialImageSelected, (state) => {
|
||||
if (tabMap[state.activeTab] !== 'img2img') {
|
||||
setActiveTabReducer(state, 'img2img');
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
@@ -126,8 +113,6 @@ export const {
|
||||
togglePinParametersPanel,
|
||||
toggleParametersPanel,
|
||||
toggleGalleryPanel,
|
||||
setDisabledPanels,
|
||||
setDisabledTabs,
|
||||
} = uiSlice.actions;
|
||||
|
||||
export default uiSlice.reducer;
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import { InvokeTabName } from './tabMap';
|
||||
|
||||
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
|
||||
|
||||
export interface UIState {
|
||||
@@ -15,6 +13,4 @@ export interface UIState {
|
||||
addNewModelUIOption: AddNewModelType;
|
||||
shouldPinGallery: boolean;
|
||||
shouldShowGallery: boolean;
|
||||
disabledParameterPanels: string[];
|
||||
disabledTabs: InvokeTabName[];
|
||||
}
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
import type { ApiRequestOptions } from './ApiRequestOptions';
|
||||
import type { ApiResult } from './ApiResult';
|
||||
|
||||
export class ApiError extends Error {
|
||||
public readonly url: string;
|
||||
public readonly status: number;
|
||||
public readonly statusText: string;
|
||||
public readonly body: any;
|
||||
public readonly request: ApiRequestOptions;
|
||||
|
||||
constructor(request: ApiRequestOptions, response: ApiResult, message: string) {
|
||||
super(message);
|
||||
|
||||
this.name = 'ApiError';
|
||||
this.url = response.url;
|
||||
this.status = response.status;
|
||||
this.statusText = response.statusText;
|
||||
this.body = response.body;
|
||||
this.request = request;
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type ApiRequestOptions = {
readonly method: 'GET' | 'PUT' | 'POST' | 'DELETE' | 'OPTIONS' | 'HEAD' | 'PATCH';
readonly url: string;
readonly path?: Record<string, any>;
readonly cookies?: Record<string, any>;
readonly headers?: Record<string, any>;
readonly query?: Record<string, any>;
readonly formData?: Record<string, any>;
readonly body?: any;
readonly mediaType?: string;
readonly responseHeader?: string;
readonly errors?: Record<number, string>;
};
@@ -1,10 +0,0 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
export type ApiResult = {
|
||||
readonly url: string;
|
||||
readonly ok: boolean;
|
||||
readonly status: number;
|
||||
readonly statusText: string;
|
||||
readonly body: any;
|
||||
};
|
||||
@@ -1,128 +0,0 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
export class CancelError extends Error {
|
||||
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = 'CancelError';
|
||||
}
|
||||
|
||||
public get isCancelled(): boolean {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
export interface OnCancel {
|
||||
readonly isResolved: boolean;
|
||||
readonly isRejected: boolean;
|
||||
readonly isCancelled: boolean;
|
||||
|
||||
(cancelHandler: () => void): void;
|
||||
}
|
||||
|
||||
export class CancelablePromise<T> implements Promise<T> {
|
||||
readonly [Symbol.toStringTag]!: string;
|
||||
|
||||
private _isResolved: boolean;
|
||||
private _isRejected: boolean;
|
||||
private _isCancelled: boolean;
|
||||
private readonly _cancelHandlers: (() => void)[];
|
||||
private readonly _promise: Promise<T>;
|
||||
private _resolve?: (value: T | PromiseLike<T>) => void;
|
||||
private _reject?: (reason?: any) => void;
|
||||
|
||||
constructor(
|
||||
executor: (
|
||||
resolve: (value: T | PromiseLike<T>) => void,
|
||||
reject: (reason?: any) => void,
|
||||
onCancel: OnCancel
|
||||
) => void
|
||||
) {
|
||||
this._isResolved = false;
|
||||
this._isRejected = false;
|
||||
this._isCancelled = false;
|
||||
this._cancelHandlers = [];
|
||||
this._promise = new Promise<T>((resolve, reject) => {
|
||||
this._resolve = resolve;
|
||||
this._reject = reject;
|
||||
|
||||
const onResolve = (value: T | PromiseLike<T>): void => {
|
||||
if (this._isResolved || this._isRejected || this._isCancelled) {
|
||||
return;
|
||||
}
|
||||
this._isResolved = true;
|
||||
this._resolve?.(value);
|
||||
};
|
||||
|
||||
const onReject = (reason?: any): void => {
|
||||
if (this._isResolved || this._isRejected || this._isCancelled) {
|
||||
return;
|
||||
}
|
||||
this._isRejected = true;
|
||||
this._reject?.(reason);
|
||||
};
|
||||
|
||||
const onCancel = (cancelHandler: () => void): void => {
|
||||
if (this._isResolved || this._isRejected || this._isCancelled) {
|
||||
return;
|
||||
}
|
||||
this._cancelHandlers.push(cancelHandler);
|
||||
};
|
||||
|
||||
Object.defineProperty(onCancel, 'isResolved', {
|
||||
get: (): boolean => this._isResolved,
|
||||
});
|
||||
|
||||
Object.defineProperty(onCancel, 'isRejected', {
|
||||
get: (): boolean => this._isRejected,
|
||||
});
|
||||
|
||||
Object.defineProperty(onCancel, 'isCancelled', {
|
||||
get: (): boolean => this._isCancelled,
|
||||
});
|
||||
|
||||
return executor(onResolve, onReject, onCancel as OnCancel);
|
||||
});
|
||||
}
|
||||
|
||||
public then<TResult1 = T, TResult2 = never>(
|
||||
onFulfilled?: ((value: T) => TResult1 | PromiseLike<TResult1>) | null,
|
||||
onRejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null
|
||||
): Promise<TResult1 | TResult2> {
|
||||
return this._promise.then(onFulfilled, onRejected);
|
||||
}
|
||||
|
||||
public catch<TResult = never>(
|
||||
onRejected?: ((reason: any) => TResult | PromiseLike<TResult>) | null
|
||||
): Promise<T | TResult> {
|
||||
return this._promise.catch(onRejected);
|
||||
}
|
||||
|
||||
public finally(onFinally?: (() => void) | null): Promise<T> {
|
||||
return this._promise.finally(onFinally);
|
||||
}
|
||||
|
||||
public cancel(): void {
|
||||
if (this._isResolved || this._isRejected || this._isCancelled) {
|
||||
return;
|
||||
}
|
||||
this._isCancelled = true;
|
||||
if (this._cancelHandlers.length) {
|
||||
try {
|
||||
for (const cancelHandler of this._cancelHandlers) {
|
||||
cancelHandler();
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Cancellation threw an error', error);
|
||||
return;
|
||||
}
|
||||
}
|
||||
this._cancelHandlers.length = 0;
|
||||
this._reject?.(new CancelError('Request aborted'));
|
||||
}
|
||||
|
||||
public get isCancelled(): boolean {
|
||||
return this._isCancelled;
|
||||
}
|
||||
}
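A short usage sketch for the class above: wrap a timer, register cleanup through `onCancel`, and abort it. The helper name is illustrative:

```ts
const delay = (ms: number): CancelablePromise<void> =>
  new CancelablePromise<void>((resolve, _reject, onCancel) => {
    const timer = setTimeout(resolve, ms);
    // Runs only if .cancel() is called before the promise settles.
    onCancel(() => clearTimeout(timer));
  });

const pending = delay(5000);
pending.catch((err) => {
  if (err instanceof CancelError) {
    console.log('aborted:', err.message); // 'Request aborted'
  }
});
pending.cancel();
```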
@@ -1,31 +0,0 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
import type { ApiRequestOptions } from './ApiRequestOptions';
|
||||
|
||||
type Resolver<T> = (options: ApiRequestOptions) => Promise<T>;
|
||||
type Headers = Record<string, string>;
|
||||
|
||||
export type OpenAPIConfig = {
|
||||
BASE: string;
|
||||
VERSION: string;
|
||||
WITH_CREDENTIALS: boolean;
|
||||
CREDENTIALS: 'include' | 'omit' | 'same-origin';
|
||||
TOKEN?: string | Resolver<string>;
|
||||
USERNAME?: string | Resolver<string>;
|
||||
PASSWORD?: string | Resolver<string>;
|
||||
HEADERS?: Headers | Resolver<Headers>;
|
||||
ENCODE_PATH?: (path: string) => string;
|
||||
};
|
||||
|
||||
export const OpenAPI: OpenAPIConfig = {
|
||||
BASE: '',
|
||||
VERSION: '1.0.0',
|
||||
WITH_CREDENTIALS: false,
|
||||
CREDENTIALS: 'include',
|
||||
TOKEN: undefined,
|
||||
USERNAME: undefined,
|
||||
PASSWORD: undefined,
|
||||
HEADERS: undefined,
|
||||
ENCODE_PATH: undefined,
|
||||
};
|
||||
@@ -1,349 +0,0 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* Custom `request.ts` file for OpenAPI code generator.
|
||||
*
|
||||
* Patches the request logic in such a way that we can extract headers from requests.
|
||||
*
|
||||
* Copied from https://github.com/ferdikoomen/openapi-typescript-codegen/issues/829#issuecomment-1228224477
|
||||
*/
|
||||
|
||||
import axios from 'axios';
|
||||
import type { AxiosError, AxiosRequestConfig, AxiosResponse } from 'axios';
|
||||
import FormData from 'form-data';
|
||||
|
||||
import { ApiError } from './ApiError';
|
||||
import type { ApiRequestOptions } from './ApiRequestOptions';
|
||||
import type { ApiResult } from './ApiResult';
|
||||
import { CancelablePromise } from './CancelablePromise';
|
||||
import type { OnCancel } from './CancelablePromise';
|
||||
import type { OpenAPIConfig } from './OpenAPI';
|
||||
|
||||
export const HEADERS = Symbol('HEADERS');
|
||||
|
||||
const isDefined = <T>(
|
||||
value: T | null | undefined
|
||||
): value is Exclude<T, null | undefined> => {
|
||||
return value !== undefined && value !== null;
|
||||
};
|
||||
|
||||
const isString = (value: any): value is string => {
|
||||
return typeof value === 'string';
|
||||
};
|
||||
|
||||
const isStringWithValue = (value: any): value is string => {
|
||||
return isString(value) && value !== '';
|
||||
};
|
||||
|
||||
const isBlob = (value: any): value is Blob => {
|
||||
return (
|
||||
typeof value === 'object' &&
|
||||
typeof value.type === 'string' &&
|
||||
typeof value.stream === 'function' &&
|
||||
typeof value.arrayBuffer === 'function' &&
|
||||
typeof value.constructor === 'function' &&
|
||||
typeof value.constructor.name === 'string' &&
|
||||
/^(Blob|File)$/.test(value.constructor.name) &&
|
||||
/^(Blob|File)$/.test(value[Symbol.toStringTag])
|
||||
);
|
||||
};
|
||||
|
||||
const isFormData = (value: any): value is FormData => {
|
||||
return value instanceof FormData;
|
||||
};
|
||||
|
||||
const isSuccess = (status: number): boolean => {
|
||||
return status >= 200 && status < 300;
|
||||
};
|
||||
|
||||
const base64 = (str: string): string => {
|
||||
try {
|
||||
return btoa(str);
|
||||
} catch (err) {
|
||||
// @ts-ignore
|
||||
return Buffer.from(str).toString('base64');
|
||||
}
|
||||
};
|
||||
|
||||
const getQueryString = (params: Record<string, any>): string => {
|
||||
const qs: string[] = [];
|
||||
|
||||
const append = (key: string, value: any) => {
|
||||
qs.push(`${encodeURIComponent(key)}=${encodeURIComponent(String(value))}`);
|
||||
};
|
||||
|
||||
const process = (key: string, value: any) => {
|
||||
if (isDefined(value)) {
|
||||
if (Array.isArray(value)) {
|
||||
value.forEach((v) => {
|
||||
process(key, v);
|
||||
});
|
||||
} else if (typeof value === 'object') {
|
||||
Object.entries(value).forEach(([k, v]) => {
|
||||
process(`${key}[${k}]`, v);
|
||||
});
|
||||
} else {
|
||||
append(key, value);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Object.entries(params).forEach(([key, value]) => {
|
||||
process(key, value);
|
||||
});
|
||||
|
||||
if (qs.length > 0) {
|
||||
return `?${qs.join('&')}`;
|
||||
}
|
||||
|
||||
return '';
|
||||
};
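For reference, how the serializer above flattens parameters: arrays repeat the key and nested objects become bracketed keys, all URL-encoded.

```ts
const qs = getQueryString({ page: 2, tags: ['cat', 'dog'], sort: { by: 'mtime' } });
// => '?page=2&tags=cat&tags=dog&sort%5Bby%5D=mtime'
```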
const getUrl = (config: OpenAPIConfig, options: ApiRequestOptions): string => {
|
||||
const encoder = config.ENCODE_PATH || encodeURI;
|
||||
|
||||
const path = options.url
|
||||
.replace('{api-version}', config.VERSION)
|
||||
.replace(/{(.*?)}/g, (substring: string, group: string) => {
|
||||
if (options.path?.hasOwnProperty(group)) {
|
||||
return encoder(String(options.path[group]));
|
||||
}
|
||||
return substring;
|
||||
});
|
||||
|
||||
const url = `${config.BASE}${path}`;
|
||||
if (options.query) {
|
||||
return `${url}${getQueryString(options.query)}`;
|
||||
}
|
||||
return url;
|
||||
};
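And a worked example of the path templating in `getUrl`, using the images route from this repo's API as an illustration; the exact options object is hypothetical:

```ts
const url = getUrl(OpenAPI, {
  method: 'GET',
  url: '/api/v1/images/{image_type}/{image_name}',
  path: { image_type: 'results', image_name: '000001.1234567890.png' },
});
// With OpenAPI.BASE === '' this yields
// '/api/v1/images/results/000001.1234567890.png'
```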
const getFormData = (options: ApiRequestOptions): FormData | undefined => {
|
||||
if (options.formData) {
|
||||
const formData = new FormData();
|
||||
|
||||
const process = (key: string, value: any) => {
|
||||
if (isString(value) || isBlob(value)) {
|
||||
formData.append(key, value);
|
||||
} else {
|
||||
formData.append(key, JSON.stringify(value));
|
||||
}
|
||||
};
|
||||
|
||||
Object.entries(options.formData)
|
||||
.filter(([_, value]) => isDefined(value))
|
||||
.forEach(([key, value]) => {
|
||||
if (Array.isArray(value)) {
|
||||
value.forEach((v) => process(key, v));
|
||||
} else {
|
||||
process(key, value);
|
||||
}
|
||||
});
|
||||
|
||||
return formData;
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
|
||||
type Resolver<T> = (options: ApiRequestOptions) => Promise<T>;
|
||||
|
||||
const resolve = async <T>(
|
||||
options: ApiRequestOptions,
|
||||
resolver?: T | Resolver<T>
|
||||
): Promise<T | undefined> => {
|
||||
if (typeof resolver === 'function') {
|
||||
return (resolver as Resolver<T>)(options);
|
||||
}
|
||||
return resolver;
|
||||
};
|
||||
|
||||
const getHeaders = async (
|
||||
config: OpenAPIConfig,
|
||||
options: ApiRequestOptions,
|
||||
formData?: FormData
|
||||
): Promise<Record<string, string>> => {
|
||||
const token = await resolve(options, config.TOKEN);
|
||||
const username = await resolve(options, config.USERNAME);
|
||||
const password = await resolve(options, config.PASSWORD);
|
||||
const additionalHeaders = await resolve(options, config.HEADERS);
|
||||
const formHeaders =
|
||||
(typeof formData?.getHeaders === 'function' && formData?.getHeaders()) ||
|
||||
{};
|
||||
|
||||
const headers = Object.entries({
|
||||
Accept: 'application/json',
|
||||
...additionalHeaders,
|
||||
...options.headers,
|
||||
...formHeaders,
|
||||
})
|
||||
.filter(([_, value]) => isDefined(value))
|
||||
.reduce(
|
||||
(headers, [key, value]) => ({
|
||||
...headers,
|
||||
[key]: String(value),
|
||||
}),
|
||||
{} as Record<string, string>
|
||||
);
|
||||
|
||||
if (isStringWithValue(token)) {
|
||||
headers['Authorization'] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
if (isStringWithValue(username) && isStringWithValue(password)) {
|
||||
const credentials = base64(`${username}:${password}`);
|
||||
headers['Authorization'] = `Basic ${credentials}`;
|
||||
}
|
||||
|
||||
if (options.body) {
|
||||
if (options.mediaType) {
|
||||
headers['Content-Type'] = options.mediaType;
|
||||
} else if (isBlob(options.body)) {
|
||||
headers['Content-Type'] = options.body.type || 'application/octet-stream';
|
||||
} else if (isString(options.body)) {
|
||||
headers['Content-Type'] = 'text/plain';
|
||||
} else if (!isFormData(options.body)) {
|
||||
headers['Content-Type'] = 'application/json';
|
||||
}
|
||||
}
|
||||
|
||||
return headers;
|
||||
};
|
||||
|
||||
const getRequestBody = (options: ApiRequestOptions): any => {
|
||||
if (options.body) {
|
||||
return options.body;
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
|
||||
const sendRequest = async <T>(
|
||||
config: OpenAPIConfig,
|
||||
options: ApiRequestOptions,
|
||||
url: string,
|
||||
body: any,
|
||||
formData: FormData | undefined,
|
||||
headers: Record<string, string>,
|
||||
onCancel: OnCancel
|
||||
): Promise<AxiosResponse<T>> => {
|
||||
const source = axios.CancelToken.source();
|
||||
|
||||
const requestConfig: AxiosRequestConfig = {
|
||||
url,
|
||||
headers,
|
||||
data: body ?? formData,
|
||||
method: options.method,
|
||||
withCredentials: config.WITH_CREDENTIALS,
|
||||
cancelToken: source.token,
|
||||
};
|
||||
|
||||
onCancel(() => source.cancel('The user aborted a request.'));
|
||||
|
||||
try {
|
||||
return await axios.request(requestConfig);
|
||||
} catch (error) {
|
||||
const axiosError = error as AxiosError<T>;
|
||||
if (axiosError.response) {
|
||||
return axiosError.response;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
const getResponseHeader = (
|
||||
response: AxiosResponse<any>,
|
||||
responseHeader?: string
|
||||
): string | undefined => {
|
||||
if (responseHeader) {
|
||||
const content = response.headers[responseHeader];
|
||||
if (isString(content)) {
|
||||
return content;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
|
||||
const getResponseBody = (response: AxiosResponse<any>): any => {
|
||||
if (response.status !== 204) {
|
||||
return response.data;
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
|
||||
const catchErrorCodes = (
|
||||
options: ApiRequestOptions,
|
||||
result: ApiResult
|
||||
): void => {
|
||||
const errors: Record<number, string> = {
|
||||
400: 'Bad Request',
|
||||
401: 'Unauthorized',
|
||||
403: 'Forbidden',
|
||||
404: 'Not Found',
|
||||
500: 'Internal Server Error',
|
||||
502: 'Bad Gateway',
|
||||
503: 'Service Unavailable',
|
||||
...options.errors,
|
||||
};
|
||||
|
||||
const error = errors[result.status];
|
||||
if (error) {
|
||||
throw new ApiError(options, result, error);
|
||||
}
|
||||
|
||||
if (!result.ok) {
|
||||
throw new ApiError(options, result, 'Generic Error');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Request method
|
||||
* @param config The OpenAPI configuration object
|
||||
* @param options The request options from the service
|
||||
* @returns CancelablePromise<T>
|
||||
* @throws ApiError
|
||||
*/
|
||||
export const request = <T>(
|
||||
config: OpenAPIConfig,
|
||||
options: ApiRequestOptions
|
||||
): CancelablePromise<T> => {
|
||||
return new CancelablePromise(async (resolve, reject, onCancel) => {
|
||||
try {
|
||||
const url = getUrl(config, options);
|
||||
const formData = getFormData(options);
|
||||
const body = getRequestBody(options);
|
||||
const headers = await getHeaders(config, options, formData);
|
||||
|
||||
if (!onCancel.isCancelled) {
|
||||
const response = await sendRequest<T>(
|
||||
config,
|
||||
options,
|
||||
url,
|
||||
body,
|
||||
formData,
|
||||
headers,
|
||||
onCancel
|
||||
);
|
||||
const responseBody = getResponseBody(response);
|
||||
const responseHeader = getResponseHeader(
|
||||
response,
|
||||
options.responseHeader
|
||||
);
|
||||
|
||||
const result: ApiResult = {
|
||||
url,
|
||||
ok: isSuccess(response.status),
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
body: responseHeader ?? responseBody,
|
||||
};
|
||||
|
||||
catchErrorCodes(options, result);
|
||||
|
||||
resolve({ ...result.body, [HEADERS]: response.headers });
|
||||
}
|
||||
} catch (error) {
|
||||
reject(error);
|
||||
}
|
||||
});
|
||||
};
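A sketch of how a generated service would typically call this patched request helper; the method name, URL, and response model here are illustrative rather than the generated ImagesService verbatim:

```ts
import { OpenAPI } from './OpenAPI';
import { request as __request } from './request';
import type { PaginatedResults_ImageResponse_ } from '../models/PaginatedResults_ImageResponse_';

export const listImages = (imageType: string, page = 0, perPage = 10) =>
  __request<PaginatedResults_ImageResponse_>(OpenAPI, {
    method: 'GET',
    url: '/api/v1/images/{image_type}/',
    path: { image_type: imageType },
    query: { page, per_page: perPage },
  });
```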
@@ -1,84 +0,0 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
export { ApiError } from './core/ApiError';
|
||||
export { CancelablePromise, CancelError } from './core/CancelablePromise';
|
||||
export { OpenAPI } from './core/OpenAPI';
|
||||
export type { OpenAPIConfig } from './core/OpenAPI';
|
||||
|
||||
export type { BlurInvocation } from './models/BlurInvocation';
|
||||
export type { Body_upload_image } from './models/Body_upload_image';
|
||||
export type { CollectInvocation } from './models/CollectInvocation';
|
||||
export type { CollectInvocationOutput } from './models/CollectInvocationOutput';
|
||||
export type { CropImageInvocation } from './models/CropImageInvocation';
|
||||
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
|
||||
export type { Edge } from './models/Edge';
|
||||
export type { EdgeConnection } from './models/EdgeConnection';
|
||||
export type { Graph } from './models/Graph';
|
||||
export type { GraphExecutionState } from './models/GraphExecutionState';
|
||||
export type { GraphInvocation } from './models/GraphInvocation';
|
||||
export type { GraphInvocationOutput } from './models/GraphInvocationOutput';
|
||||
export type { HTTPValidationError } from './models/HTTPValidationError';
|
||||
export type { ImageField } from './models/ImageField';
|
||||
export type { ImageMetadata } from './models/ImageMetadata';
|
||||
export type { ImageOutput } from './models/ImageOutput';
|
||||
export type { ImageResponse } from './models/ImageResponse';
|
||||
export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
|
||||
export type { ImageType } from './models/ImageType';
|
||||
export type { InpaintInvocation } from './models/InpaintInvocation';
|
||||
export type { InverseLerpInvocation } from './models/InverseLerpInvocation';
|
||||
export type { IterateInvocation } from './models/IterateInvocation';
|
||||
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
|
||||
export type { LerpInvocation } from './models/LerpInvocation';
export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput';
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
export type { PaginatedResults_ImageResponse_ } from './models/PaginatedResults_ImageResponse_';
export type { PasteImageInvocation } from './models/PasteImageInvocation';
export type { PromptOutput } from './models/PromptOutput';
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
export type { ShowImageInvocation } from './models/ShowImageInvocation';
export type { TextToImageInvocation } from './models/TextToImageInvocation';
export type { UpscaleInvocation } from './models/UpscaleInvocation';
export type { ValidationError } from './models/ValidationError';

export { $BlurInvocation } from './schemas/$BlurInvocation';
export { $Body_upload_image } from './schemas/$Body_upload_image';
export { $CollectInvocation } from './schemas/$CollectInvocation';
export { $CollectInvocationOutput } from './schemas/$CollectInvocationOutput';
export { $CropImageInvocation } from './schemas/$CropImageInvocation';
export { $CvInpaintInvocation } from './schemas/$CvInpaintInvocation';
export { $Edge } from './schemas/$Edge';
export { $EdgeConnection } from './schemas/$EdgeConnection';
export { $Graph } from './schemas/$Graph';
export { $GraphExecutionState } from './schemas/$GraphExecutionState';
export { $GraphInvocation } from './schemas/$GraphInvocation';
export { $GraphInvocationOutput } from './schemas/$GraphInvocationOutput';
export { $HTTPValidationError } from './schemas/$HTTPValidationError';
export { $ImageField } from './schemas/$ImageField';
export { $ImageMetadata } from './schemas/$ImageMetadata';
export { $ImageOutput } from './schemas/$ImageOutput';
export { $ImageResponse } from './schemas/$ImageResponse';
export { $ImageToImageInvocation } from './schemas/$ImageToImageInvocation';
export { $ImageType } from './schemas/$ImageType';
export { $InpaintInvocation } from './schemas/$InpaintInvocation';
export { $InverseLerpInvocation } from './schemas/$InverseLerpInvocation';
export { $IterateInvocation } from './schemas/$IterateInvocation';
export { $IterateInvocationOutput } from './schemas/$IterateInvocationOutput';
export { $LerpInvocation } from './schemas/$LerpInvocation';
export { $LoadImageInvocation } from './schemas/$LoadImageInvocation';
export { $MaskFromAlphaInvocation } from './schemas/$MaskFromAlphaInvocation';
export { $MaskOutput } from './schemas/$MaskOutput';
export { $PaginatedResults_GraphExecutionState_ } from './schemas/$PaginatedResults_GraphExecutionState_';
export { $PaginatedResults_ImageResponse_ } from './schemas/$PaginatedResults_ImageResponse_';
export { $PasteImageInvocation } from './schemas/$PasteImageInvocation';
export { $PromptOutput } from './schemas/$PromptOutput';
export { $RestoreFaceInvocation } from './schemas/$RestoreFaceInvocation';
export { $ShowImageInvocation } from './schemas/$ShowImageInvocation';
export { $TextToImageInvocation } from './schemas/$TextToImageInvocation';
export { $UpscaleInvocation } from './schemas/$UpscaleInvocation';
export { $ValidationError } from './schemas/$ValidationError';

export { ImagesService } from './services/ImagesService';
export { SessionsService } from './services/SessionsService';
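A note on how these generated exports differ in kind: the `export type` entries are erased at compile time and only describe request/response shapes, while the `$`-prefixed exports are schema objects that exist at runtime. A minimal sketch of the distinction, assuming the barrel is imported from a path like `services/api` (the path and usage below are illustrative, not taken from this diff):

```typescript
// Type-only import: erased from the emitted JavaScript.
import type { MaskOutput } from 'services/api';
// Value import: the $ schema object is a real object at runtime.
import { $MaskOutput } from 'services/api';

// The type can annotate values coming back from the API...
function describeMask(output: MaskOutput): string {
  return JSON.stringify(output);
}

// ...while the schema object can be inspected, e.g. for generic form generation.
console.log(Object.keys($MaskOutput));
```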
@@ -1,29 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Blurs an image
 */
export type BlurInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'blur';
  /**
   * The image to blur
   */
  image?: ImageField;
  /**
   * The blur radius
   */
  radius?: number;
  /**
   * The type of blur
   */
  blur_type?: 'gaussian' | 'box';
};
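As a quick illustration of the node shape this model describes, a blur node is just a plain object literal: only `id` is required, and `image` (an `ImageField`, whose definition is not part of this excerpt) would normally be filled in by an edge from an upstream node. The id and values below are illustrative assumptions:

```typescript
import type { BlurInvocation } from './BlurInvocation';

// Minimal blur node; 'blur-1' and the radius are illustrative values.
const blurNode: BlurInvocation = {
  id: 'blur-1',
  type: 'blur',
  radius: 8,
  blur_type: 'gaussian',
  // `image` is left undefined here; it is typically supplied by connecting
  // an upstream node's image output via an Edge.
};
```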
@@ -1,8 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

export type Body_upload_image = {
  file: Blob;
};
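`Body_upload_image` is the multipart body of the image upload route. A rough browser-side sketch follows; the endpoint path is an assumption for illustration, since the route itself is not shown in this excerpt:

```typescript
// Hypothetical helper; the '/api/v1/images/uploads' path is assumed, not taken from this diff.
async function uploadImage(file: Blob): Promise<Response> {
  const formData = new FormData();
  formData.append('file', file); // corresponds to Body_upload_image.file

  const res = await fetch('/api/v1/images/uploads', {
    method: 'POST',
    body: formData, // fetch sets the multipart boundary automatically
  });
  if (!res.ok) {
    throw new Error(`Upload failed with status ${res.status}`);
  }
  return res;
}
```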
@@ -1,23 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

/**
 * Collects values into a collection
 */
export type CollectInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'collect';
  /**
   * The item to collect (all inputs must be of the same type)
   */
  item?: any;
  /**
   * The collection, will be provided on execution
   */
  collection?: Array<any>;
};
@@ -1,15 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

/**
 * Base class for all invocation outputs
 */
export type CollectInvocationOutput = {
  type: 'collect_output';
  /**
   * The collection of input items
   */
  collection: Array<any>;
};
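`CollectInvocation` and `CollectInvocationOutput` together describe the fan-in pattern: upstream values arrive one at a time on `item`, and the executed node reports a single `collection`. A small sketch with illustrative ids and values:

```typescript
import type { CollectInvocation } from './CollectInvocation';
import type { CollectInvocationOutput } from './CollectInvocationOutput';

// The node as submitted in a graph; `item` is normally fed by edges at execution time.
const collector: CollectInvocation = {
  id: 'collect-1',
  type: 'collect',
};

// The shape reported once the node has run (values are illustrative).
const collected: CollectInvocationOutput = {
  type: 'collect_output',
  collection: ['first item', 'second item'],
};
```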
@@ -1,37 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Crops an image to a specified box. The box can be outside of the image.
 */
export type CropImageInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'crop';
  /**
   * The image to crop
   */
  image?: ImageField;
  /**
   * The left x coordinate of the crop rectangle
   */
  'x'?: number;
  /**
   * The top y coordinate of the crop rectangle
   */
  'y'?: number;
  /**
   * The width of the crop rectangle
   */
  width?: number;
  /**
   * The height of the crop rectangle
   */
  height?: number;
};
@@ -1,25 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Simple inpaint using opencv.
 */
export type CvInpaintInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'cv_inpaint';
  /**
   * The image to inpaint
   */
  image?: ImageField;
  /**
   * The mask to use when inpainting
   */
  mask?: ImageField;
};
@@ -1,17 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { EdgeConnection } from './EdgeConnection';

export type Edge = {
  /**
   * The connection for the edge's from node and field
   */
  source: EdgeConnection;
  /**
   * The connection for the edge's to node and field
   */
  destination: EdgeConnection;
};
@@ -1,15 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

export type EdgeConnection = {
  /**
   * The id of the node for this edge connection
   */
  node_id: string;
  /**
   * The field for this connection
   */
  field: string;
};
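`Edge` and `EdgeConnection` are what tie individual invocation nodes into a graph: an edge names a source node/field pair and a destination node/field pair. A minimal sketch, using illustrative node ids and a field name that would have to match the connected models' properties:

```typescript
import type { Edge } from './Edge';
import type { EdgeConnection } from './EdgeConnection';

// Feed the output image of a hypothetical 'blur-1' node into the `image`
// input of a 'crop-1' node. Ids and the field name are illustrative.
const source: EdgeConnection = { node_id: 'blur-1', field: 'image' };
const destination: EdgeConnection = { node_id: 'crop-1', field: 'image' };

const edge: Edge = { source, destination };
```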
Some files were not shown because too many files have changed in this diff.