mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-17 05:28:10 -05:00
Compare commits
31 Commits
feat/ui/me
...
pre-nodes
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fdad62e88b | ||
|
|
955c81acef | ||
|
|
e1058f3416 | ||
|
|
edf16a253d | ||
|
|
46f5ef4100 | ||
|
|
b843255236 | ||
|
|
3a968e5072 | ||
|
|
fd80e84ea6 | ||
|
|
4824237a98 | ||
|
|
2c9a05eb59 | ||
|
|
ecb5bdaf7e | ||
|
|
f6cdff2c5b | ||
|
|
63d10027a4 | ||
|
|
ef0773b8a3 | ||
|
|
3daaddf15b | ||
|
|
570c3fe690 | ||
|
|
cbd1a7263a | ||
|
|
7fc5fbd4ce | ||
|
|
6f6de402ad | ||
|
|
281662a6e1 | ||
|
|
50eb02f68b | ||
|
|
d73f3adc43 | ||
|
|
116107f464 | ||
|
|
da44bb1707 | ||
|
|
f43aed677e | ||
|
|
0d051aaae2 | ||
|
|
e4e48ff995 | ||
|
|
442a6bffa4 | ||
|
|
23d65e7162 | ||
|
|
024fd54d0b | ||
|
|
c44c19e911 |
14
.github/CODEOWNERS
vendored
14
.github/CODEOWNERS
vendored
@@ -1,16 +1,16 @@
|
||||
# continuous integration
|
||||
/.github/workflows/ @mauwii @lstein @blessedcoolant
|
||||
/.github/workflows/ @lstein @blessedcoolant
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @mauwii @tildebyte @blessedcoolant
|
||||
/mkdocs.yml @lstein @mauwii @blessedcoolant
|
||||
/docs/ @lstein @tildebyte @blessedcoolant
|
||||
/mkdocs.yml @lstein @blessedcoolant
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @mauwii @lstein @blessedcoolant
|
||||
/docker/ @mauwii @lstein @blessedcoolant
|
||||
/pyproject.toml @lstein @blessedcoolant
|
||||
/docker/ @lstein @blessedcoolant
|
||||
/scripts/ @ebr @lstein
|
||||
/installer/ @lstein @ebr
|
||||
/invokeai/assets @lstein @ebr
|
||||
@@ -22,11 +22,11 @@
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @keturn @damian0815 @lstein @blessedcoolant @jpphoto
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein
|
||||
/invokeai/frontend/install @lstein @ebr @mauwii
|
||||
/invokeai/frontend/install @lstein @ebr
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -9,6 +9,8 @@ models/ldm/stable-diffusion-v1/model.ckpt
|
||||
configs/models.user.yaml
|
||||
config/models.user.yml
|
||||
invokeai.init
|
||||
.version
|
||||
.last_model
|
||||
|
||||
# ignore the Anaconda/Miniconda installer used while building Docker image
|
||||
anaconda.sh
|
||||
|
||||
@@ -84,7 +84,7 @@ installing lots of models.
|
||||
|
||||
6. Wait while the installer does its thing. After installing the software,
|
||||
the installer will launch a script that lets you configure InvokeAI and
|
||||
select a set of starting image generaiton models.
|
||||
select a set of starting image generation models.
|
||||
|
||||
7. Find the folder that InvokeAI was installed into (it is not the
|
||||
same as the unpacked zip file directory!) The default location of this
|
||||
@@ -148,6 +148,11 @@ not supported.
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
```
|
||||
|
||||
_For non-GPU systems:_
|
||||
```terminal
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
_For Macintoshes, either Intel or M1/M2:_
|
||||
|
||||
```sh
|
||||
|
||||
@@ -32,7 +32,7 @@ turned on and off on the command line using `--nsfw_checker` and
|
||||
At installation time, InvokeAI will ask whether the checker should be
|
||||
activated by default (neither argument given on the command line). The
|
||||
response is stored in the InvokeAI initialization file (usually
|
||||
`.invokeai` in your home directory). You can change the default at any
|
||||
`invokeai.init` in your home directory). You can change the default at any
|
||||
time by opening this file in a text editor and commenting or
|
||||
uncommenting the line `--nsfw_checker`.
|
||||
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
import os
|
||||
from argparse import Namespace
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ...backend import Globals
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
@@ -69,6 +71,9 @@ class ApiDependencies:
|
||||
latents=latents,
|
||||
images=images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
@@ -76,6 +81,8 @@ class ApiDependencies:
|
||||
restoration=RestorationServices(config),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
import io
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from fastapi import Path, Query, Request, UploadFile
|
||||
@@ -10,7 +8,6 @@ from fastapi.responses import FileResponse, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from invokeai.app.api.models.images import ImageResponse
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
from ...services.image_storage import ImageType
|
||||
@@ -43,47 +40,32 @@ async def get_thumbnail(
|
||||
"/uploads/",
|
||||
operation_id="upload_image",
|
||||
responses={
|
||||
201: {"description": "The image was uploaded successfully", "model": ImageResponse},
|
||||
201: {"description": "The image was uploaded successfully"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
status_code=201
|
||||
)
|
||||
async def upload_image(file: UploadFile, request: Request, response: Response) -> ImageResponse:
|
||||
async def upload_image(file: UploadFile, request: Request):
|
||||
if not file.content_type.startswith("image"):
|
||||
return Response(status_code=415)
|
||||
|
||||
contents = await file.read()
|
||||
try:
|
||||
img = Image.open(io.BytesIO(contents))
|
||||
im = Image.open(contents)
|
||||
except:
|
||||
# Error opening the image
|
||||
return Response(status_code=415)
|
||||
|
||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
image_path = ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, img)
|
||||
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
|
||||
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
|
||||
|
||||
res = ImageResponse(
|
||||
image_type=ImageType.UPLOAD,
|
||||
image_name=filename,
|
||||
# TODO: DiskImageStorage should not be building URLs...?
|
||||
image_url=f"api/v1/images/{ImageType.UPLOAD.value}/{filename}",
|
||||
thumbnail_url=f"api/v1/images/{ImageType.UPLOAD.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
|
||||
metadata=ImageMetadata(
|
||||
created=int(os.path.getctime(image_path)),
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
invokeai=invokeai_metadata
|
||||
),
|
||||
)
|
||||
|
||||
response.status_code = 201
|
||||
response.headers["Location"] = request.url_for(
|
||||
return Response(
|
||||
status_code=201,
|
||||
headers={
|
||||
"Location": request.url_for(
|
||||
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
|
||||
)
|
||||
|
||||
return res
|
||||
},
|
||||
)
|
||||
|
||||
@images_router.get(
|
||||
"/",
|
||||
|
||||
@@ -7,11 +7,40 @@ from pydantic import BaseModel, Field
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ..models.image import ImageField
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..invocations.image import ImageField
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
def add_parsers(
|
||||
subparsers,
|
||||
commands: list[type],
|
||||
@@ -36,30 +65,26 @@ def add_parsers(
|
||||
if name in exclude_fields:
|
||||
continue
|
||||
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
add_field_argument(command_parser, name, field)
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=field.default if field.default_factory is None else field.default_factory(),
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=field.default if field.default_factory is None else field.default_factory(),
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
def add_graph_parsers(
|
||||
subparsers,
|
||||
graphs: list[LibraryGraph],
|
||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
||||
):
|
||||
for graph in graphs:
|
||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Add arguments for inputs
|
||||
for exposed_input in graph.exposed_inputs:
|
||||
node = graph.graph.get_node(exposed_input.node_path)
|
||||
field = node.__fields__[exposed_input.field]
|
||||
default_override = getattr(node, exposed_input.field)
|
||||
add_field_argument(command_parser, exposed_input.alias, field, default_override)
|
||||
|
||||
|
||||
class CliContext:
|
||||
@@ -67,17 +92,38 @@ class CliContext:
|
||||
session: GraphExecutionState
|
||||
parser: argparse.ArgumentParser
|
||||
defaults: dict[str, Any]
|
||||
graph_nodes: dict[str, str]
|
||||
nodes_added: list[str]
|
||||
|
||||
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
|
||||
self.invoker = invoker
|
||||
self.session = session
|
||||
self.parser = parser
|
||||
self.defaults = dict()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
|
||||
def get_session(self):
|
||||
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
|
||||
return self.session
|
||||
|
||||
def reset(self):
|
||||
self.session = self.invoker.create_execution_state()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
# Leave defaults unchanged
|
||||
|
||||
def add_node(self, node: BaseInvocation):
|
||||
self.get_session()
|
||||
self.session.graph.add_node(node)
|
||||
self.nodes_added.append(node.id)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
def add_edge(self, edge: Edge):
|
||||
self.get_session()
|
||||
self.session.add_edge(edge)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
|
||||
class ExitCli(Exception):
|
||||
"""Exception to exit the CLI"""
|
||||
|
||||
@@ -13,17 +13,20 @@ from typing import (
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from .services.default_graphs import create_system_graphs
|
||||
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
|
||||
from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
@@ -58,7 +61,7 @@ def add_invocation_args(command_parser):
|
||||
)
|
||||
|
||||
|
||||
def get_command_parser() -> argparse.ArgumentParser:
|
||||
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -76,20 +79,72 @@ def get_command_parser() -> argparse.ArgumentParser:
|
||||
commands = BaseCommand.get_all_subclasses()
|
||||
add_parsers(subparsers, commands, exclude_fields=["type"])
|
||||
|
||||
# Create subparsers for exposed CLI graphs
|
||||
# TODO: add a way to identify these graphs
|
||||
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
|
||||
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class NodeField():
|
||||
alias: str
|
||||
node_path: str
|
||||
field: str
|
||||
field_type: type
|
||||
|
||||
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
|
||||
self.alias = alias
|
||||
self.node_path = node_path
|
||||
self.field = field
|
||||
self.field_type = field_type
|
||||
|
||||
|
||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
|
||||
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||
|
||||
|
||||
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
||||
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
|
||||
|
||||
|
||||
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
||||
node_output_type = node_type.get_output_type()
|
||||
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
|
||||
|
||||
|
||||
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the inputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
|
||||
|
||||
|
||||
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the outputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
||||
|
||||
|
||||
def generate_matching_edges(
|
||||
a: BaseInvocation, b: BaseInvocation
|
||||
a: BaseInvocation, b: BaseInvocation, context: CliContext
|
||||
) -> list[Edge]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
atype = type(a)
|
||||
btype = type(b)
|
||||
|
||||
aoutputtype = atype.get_output_type()
|
||||
|
||||
afields = get_type_hints(aoutputtype)
|
||||
bfields = get_type_hints(btype)
|
||||
afields = get_node_outputs(a, context)
|
||||
bfields = get_node_inputs(b, context)
|
||||
|
||||
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
||||
|
||||
@@ -98,14 +153,14 @@ def generate_matching_edges(
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
# Validate types
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
|
||||
|
||||
edges = [
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=a.id, field=field),
|
||||
destination=EdgeConnection(node_id=b.id, field=field)
|
||||
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
|
||||
)
|
||||
for field in matching_fields
|
||||
for alias in matching_fields
|
||||
]
|
||||
return edges
|
||||
|
||||
@@ -158,6 +213,9 @@ def invoke_cli():
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
images=DiskImageStorage(f'{output_folder}/images'),
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
@@ -165,9 +223,12 @@ def invoke_cli():
|
||||
restoration=RestorationServices(config),
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
system_graph_names = set([g.name for g in system_graphs])
|
||||
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
parser = get_command_parser()
|
||||
parser = get_command_parser(services)
|
||||
|
||||
re_negid = re.compile('^-[0-9]+$')
|
||||
|
||||
@@ -185,11 +246,12 @@ def invoke_cli():
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
history = list(get_graph_execution_history(context.session))
|
||||
#history = list(get_graph_execution_history(context.session))
|
||||
history = list(reversed(context.nodes_added))
|
||||
|
||||
# Split the command for piping
|
||||
cmds = cmd_input.split("|")
|
||||
start_id = len(history)
|
||||
start_id = len(context.nodes_added)
|
||||
current_id = start_id
|
||||
new_invocations = list()
|
||||
for cmd in cmds:
|
||||
@@ -205,8 +267,24 @@ def invoke_cli():
|
||||
args[field_name] = field_default
|
||||
|
||||
# Parse invocation
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
command: CliCommand = None # type:ignore
|
||||
system_graph: LibraryGraph|None = None
|
||||
if args['type'] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||
for exposed_input in system_graph.exposed_inputs:
|
||||
if exposed_input.alias in args:
|
||||
node = invocation.graph.get_node(exposed_input.node_path)
|
||||
field = exposed_input.field
|
||||
setattr(node, field, args[exposed_input.alias])
|
||||
command = CliCommand(command = invocation)
|
||||
context.graph_nodes[invocation.id] = system_graph.id
|
||||
else:
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
|
||||
if command is None:
|
||||
continue
|
||||
|
||||
# Run any CLI commands immediately
|
||||
if isinstance(command.command, BaseCommand):
|
||||
@@ -217,6 +295,7 @@ def invoke_cli():
|
||||
command.command.run(context)
|
||||
continue
|
||||
|
||||
# TODO: handle linking with library graphs
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges: list[Edge] = list()
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
@@ -229,7 +308,7 @@ def invoke_cli():
|
||||
else context.session.graph.get_node(from_id)
|
||||
)
|
||||
matching_edges = generate_matching_edges(
|
||||
from_node, command.command
|
||||
from_node, command.command, context
|
||||
)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
@@ -242,7 +321,7 @@ def invoke_cli():
|
||||
|
||||
link_node = context.session.graph.get_node(node_id)
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command
|
||||
link_node, command.command, context
|
||||
)
|
||||
matching_destinations = [e.destination for e in matching_edges]
|
||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||
@@ -256,12 +335,14 @@ def invoke_cli():
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
# TODO: handle missing input/output
|
||||
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
|
||||
node_input = get_node_inputs(command.command, context)[link[2]]
|
||||
|
||||
edges.append(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=node_id, field=link[1]),
|
||||
destination=EdgeConnection(
|
||||
node_id=command.command.id, field=link[2]
|
||||
)
|
||||
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -270,10 +351,10 @@ def invoke_cli():
|
||||
current_id = current_id + 1
|
||||
|
||||
# Add the node to the session
|
||||
context.session.add_node(command.command)
|
||||
context.add_node(command.command)
|
||||
for edge in edges:
|
||||
print(edge)
|
||||
context.session.add_edge(edge)
|
||||
context.add_edge(edge)
|
||||
|
||||
# Execute all remaining nodes
|
||||
invoke_all(context)
|
||||
@@ -285,7 +366,7 @@ def invoke_cli():
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
print("Session error: creating a new session")
|
||||
context.session = context.invoker.create_execution_state()
|
||||
context.reset()
|
||||
|
||||
except ExitCli:
|
||||
break
|
||||
|
||||
@@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class CvInvocationConfig(BaseModel):
|
||||
@@ -56,9 +56,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image_inpainted, self.dict())
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image_inpainted,
|
||||
context.services.images.save(image_type, image_name, image_inpainted)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
@@ -9,9 +9,9 @@ from torch import Tensor
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.invocations.util.get_model import choose_model
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from .image import ImageOutput
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..models.exceptions import CanceledException
|
||||
@@ -76,7 +76,6 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
self.model_name = model["model_name"]
|
||||
|
||||
outputs = Txt2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
@@ -96,22 +95,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
invocation = graph_execution_state.execution_graph.get_node(self.id)
|
||||
|
||||
metadata = {
|
||||
"session": context.graph_execution_state_id,
|
||||
"source_id": source_id,
|
||||
"invocation": invocation.dict()
|
||||
}
|
||||
|
||||
context.services.images.save(image_type, image_name, generate_output.image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=generate_output.image
|
||||
context.services.images.save(image_type, image_name, generate_output.image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
@@ -158,7 +144,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
self.model = model["model_name"]
|
||||
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
@@ -183,11 +168,9 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, result_image, self.dict())
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=result_image
|
||||
context.services.images.save(image_type, image_name, result_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
|
||||
@@ -235,8 +218,7 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
)
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
self.model = model["model_name"]
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
outputs = Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
@@ -261,9 +243,7 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, result_image, self.dict())
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=result_image
|
||||
context.services.images.save(image_type, image_name, result_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
@@ -9,12 +9,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from ..models.image import ImageField, ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
|
||||
|
||||
class PILInvocationConfig(BaseModel):
|
||||
@@ -27,70 +22,51 @@ class PILInvocationConfig(BaseModel):
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["image"] = "image"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"image",
|
||||
"width",
|
||||
"height",
|
||||
'required': [
|
||||
'type',
|
||||
'image',
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def build_image_output(
|
||||
image_type: ImageType, image_name: str, image: Image.Image
|
||||
) -> ImageOutput:
|
||||
image_field = ImageField(image_name=image_name, image_type=image_type)
|
||||
|
||||
return ImageOutput(image=image_field, width=image.width, height=image.height)
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"mask",
|
||||
'required': [
|
||||
'type',
|
||||
'mask',
|
||||
]
|
||||
}
|
||||
|
||||
# TODO: this isn't really necessary anymore
|
||||
class LoadImageInvocation(BaseInvocation):
|
||||
"""Load an image from a filename and provide it as output."""
|
||||
#fmt: off
|
||||
type: Literal["load_image"] = "load_image"
|
||||
|
||||
# # TODO: this isn't really necessary anymore
|
||||
# class LoadImageInvocation(BaseInvocation):
|
||||
# """Load an image from a filename and provide it as output."""
|
||||
# #fmt: off
|
||||
# type: Literal["load_image"] = "load_image"
|
||||
# Inputs
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
#fmt: on
|
||||
|
||||
# # Inputs
|
||||
# image_type: ImageType = Field(description="The type of the image")
|
||||
# image_name: str = Field(description="The name of the image")
|
||||
# #fmt: on
|
||||
|
||||
# def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# return ImageOutput(
|
||||
# image_type=self.image_type,
|
||||
# image_name=self.image_name,
|
||||
# image=result_image
|
||||
# )
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=self.image_type, image_name=self.image_name)
|
||||
)
|
||||
|
||||
|
||||
class ShowImageInvocation(BaseInvocation):
|
||||
@@ -110,17 +86,16 @@ class ShowImageInvocation(BaseInvocation):
|
||||
|
||||
# TODO: how to handle failure?
|
||||
|
||||
return build_image_output(
|
||||
image_type=self.image.image_type,
|
||||
image_name=self.image.image_name,
|
||||
image=image,
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_type=self.image.image_type, image_name=self.image.image_name
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["crop"] = "crop"
|
||||
|
||||
# Inputs
|
||||
@@ -129,7 +104,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
@@ -145,16 +120,15 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image_crop, self.dict())
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=image_crop
|
||||
context.services.images.save(image_type, image_name, image_crop)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Pastes an image into another image."""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["paste"] = "paste"
|
||||
|
||||
# Inputs
|
||||
@@ -163,7 +137,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get(
|
||||
@@ -196,22 +170,21 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, new_image, self.dict())
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=new_image
|
||||
context.services.images.save(image_type, image_name, new_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["tomask"] = "tomask"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to create the mask from")
|
||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.services.images.get(
|
||||
@@ -226,22 +199,22 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image_mask, self.dict())
|
||||
context.services.images.save(image_type, image_name, image_mask)
|
||||
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
||||
|
||||
|
||||
class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Blurs an image"""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["blur"] = "blur"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to blur")
|
||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||
# fmt: on
|
||||
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
@@ -258,23 +231,22 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, blur_image, self.dict())
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=blur_image
|
||||
context.services.images.save(image_type, image_name, blur_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["lerp"] = "lerp"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
@@ -290,24 +262,23 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, lerp_image, self.dict())
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=lerp_image
|
||||
context.services.images.save(image_type, image_name, lerp_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["ilerp"] = "ilerp"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||
# fmt: on
|
||||
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
@@ -327,7 +298,7 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, ilerp_image, self.dict())
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=ilerp_image
|
||||
context.services.images.save(image_type, image_name, ilerp_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import random
|
||||
from typing import Literal, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.invocations.util.get_model import choose_model
|
||||
from invokeai.app.util.step_callback import diffusers_step_callback_adapter
|
||||
|
||||
from ...backend.model_management.model_manager import ModelManager
|
||||
@@ -18,7 +19,7 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont
|
||||
import numpy as np
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput, build_image_output
|
||||
from .image import ImageField, ImageOutput
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
import diffusers
|
||||
@@ -99,13 +100,17 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
|
||||
return x
|
||||
|
||||
|
||||
def random_seed():
|
||||
return random.randint(0, np.iinfo(np.uint32).max)
|
||||
|
||||
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
type: Literal["noise"] = "noise"
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
|
||||
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||
|
||||
@@ -166,7 +171,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
) -> None:
|
||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||
raise CanceledException
|
||||
|
||||
@@ -180,7 +185,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||
|
||||
|
||||
|
||||
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||
model_info = choose_model(model_manager, self.model)
|
||||
model_name = model_info['model_name']
|
||||
@@ -190,7 +195,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
model=model,
|
||||
scheduler_name=self.scheduler
|
||||
)
|
||||
|
||||
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
@@ -287,7 +292,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=model.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
@@ -350,9 +355,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image, self.dict())
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image
|
||||
context.services.images.save(image_type, image_name, image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
18
invokeai/app/invocations/params.py
Normal file
18
invokeai/app/invocations/params.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
from pydantic import Field
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from .math import IntOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
|
||||
class ParamIntInvocation(BaseInvocation):
|
||||
"""An integer parameter"""
|
||||
#fmt: off
|
||||
type: Literal["param_int"] = "param_int"
|
||||
a: int = Field(default=0, description="The integer value")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a)
|
||||
@@ -6,7 +6,7 @@ from pydantic import Field
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from .image import ImageOutput
|
||||
|
||||
class RestoreFaceInvocation(BaseInvocation):
|
||||
"""Restores faces in an image."""
|
||||
@@ -44,9 +44,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, results[0][0], self.dict())
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=results[0][0]
|
||||
)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic import Field
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
@@ -49,9 +49,7 @@ class UpscaleInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, results[0][0], self.dict())
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=results[0][0]
|
||||
)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.backend.model_management.model_manager import ModelManager
|
||||
|
||||
|
||||
def choose_model(model_manager: ModelManager, model_name: str):
|
||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||
if model_manager.valid_model(model_name):
|
||||
model = model_manager.get_model(model_name)
|
||||
return model_manager.get_model(model_name)
|
||||
else:
|
||||
model = model_manager.get_model()
|
||||
print(
|
||||
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
|
||||
)
|
||||
|
||||
return model
|
||||
print(f"* Warning: '{model_name}' is not a valid model name. Using default model instead.")
|
||||
return model_manager.get_model()
|
||||
@@ -1,26 +1,11 @@
|
||||
from typing import Any, Optional, Dict
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InvokeAIMetadata(BaseModel):
|
||||
"""An image's InvokeAI-specific metadata"""
|
||||
|
||||
session: Optional[str] = Field(description="The session that generated this image")
|
||||
source_id: Optional[str] = Field(
|
||||
description="The source id of the invocation that generated this image"
|
||||
)
|
||||
# TODO: figure out metadata
|
||||
invocation: Optional[Dict[str, Any]] = Field(
|
||||
default={}, description="The prepared invocation that generated this image"
|
||||
)
|
||||
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
"""An image's general metadata"""
|
||||
"""An image's metadata"""
|
||||
|
||||
created: int = Field(description="The creation timestamp of the image")
|
||||
timestamp: float = Field(description="The creation timestamp of the image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
invokeai: Optional[InvokeAIMetadata] = Field(
|
||||
default={}, description="The image's InvokeAI-specific metadata"
|
||||
)
|
||||
# TODO: figure out metadata
|
||||
sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")
|
||||
|
||||
56
invokeai/app/services/default_graphs.py
Normal file
56
invokeai/app/services/default_graphs.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
|
||||
from ..invocations.params import ParamIntInvocation
|
||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
|
||||
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
|
||||
|
||||
|
||||
def create_text_to_image() -> LibraryGraph:
|
||||
return LibraryGraph(
|
||||
id=default_text_to_image_graph_id,
|
||||
name='t2i',
|
||||
description='Converts text to an image',
|
||||
graph=Graph(
|
||||
nodes={
|
||||
'width': ParamIntInvocation(id='width', a=512),
|
||||
'height': ParamIntInvocation(id='height', a=512),
|
||||
'3': NoiseInvocation(id='3'),
|
||||
'4': TextToLatentsInvocation(id='4'),
|
||||
'5': LatentsToImageInvocation(id='5')
|
||||
},
|
||||
edges=[
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
|
||||
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
|
||||
]
|
||||
),
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
|
||||
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
||||
ExposedNodeInput(node_path='height', field='a', alias='height')
|
||||
],
|
||||
exposed_outputs=[
|
||||
ExposedNodeOutput(node_path='5', field='image', alias='image')
|
||||
])
|
||||
|
||||
|
||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
graphs: list[LibraryGraph] = list()
|
||||
|
||||
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
# TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
#if text_to_image is None:
|
||||
text_to_image = create_text_to_image()
|
||||
graph_library.set(text_to_image)
|
||||
|
||||
graphs.append(text_to_image)
|
||||
|
||||
return graphs
|
||||
@@ -25,8 +25,7 @@ class EventServiceBase:
|
||||
def emit_generator_progress(
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
invocation_dict: dict,
|
||||
source_id: str,
|
||||
invocation_id: str,
|
||||
progress_image: ProgressImage | None,
|
||||
step: int,
|
||||
total_steps: int,
|
||||
@@ -36,8 +35,7 @@ class EventServiceBase:
|
||||
event_name="generator_progress",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation=invocation_dict,
|
||||
source_id=source_id,
|
||||
invocation_id=invocation_id,
|
||||
progress_image=progress_image,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
@@ -45,43 +43,40 @@ class EventServiceBase:
|
||||
)
|
||||
|
||||
def emit_invocation_complete(
|
||||
self, graph_execution_state_id: str, result: Dict, invocation_dict: Dict, source_id: str,
|
||||
self, graph_execution_state_id: str, invocation_id: str, result: Dict
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_complete",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation=invocation_dict,
|
||||
source_id=source_id,
|
||||
invocation_id=invocation_id,
|
||||
result=result,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_error(
|
||||
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str, error: str
|
||||
self, graph_execution_state_id: str, invocation_id: str, error: str
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_error",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation=invocation_dict,
|
||||
source_id=source_id,
|
||||
invocation_id=invocation_id,
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_started(
|
||||
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str
|
||||
self, graph_execution_state_id: str, invocation_id: str
|
||||
) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_started",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation=invocation_dict,
|
||||
source_id=source_id,
|
||||
invocation_id=invocation_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import traceback
|
||||
import uuid
|
||||
from types import NoneType
|
||||
from typing import (
|
||||
@@ -17,7 +16,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ..invocations import *
|
||||
@@ -26,7 +25,6 @@ from ..invocations.baseinvocation import (
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
)
|
||||
from .invocation_services import InvocationServices
|
||||
|
||||
|
||||
class EdgeConnection(BaseModel):
|
||||
@@ -215,7 +213,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
|
||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
@@ -283,7 +281,8 @@ class Graph(BaseModel):
|
||||
:raises InvalidEdgeError: the provided edge is invalid.
|
||||
"""
|
||||
|
||||
if self._is_edge_valid(edge) and edge not in self.edges:
|
||||
self._validate_edge(edge)
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
else:
|
||||
raise InvalidEdgeError()
|
||||
@@ -354,7 +353,7 @@ class Graph(BaseModel):
|
||||
|
||||
return True
|
||||
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
def _validate_edge(self, edge: Edge):
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
@@ -362,54 +361,53 @@ class Graph(BaseModel):
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
raise InvalidEdgeError("One or both nodes don't exist")
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
|
||||
|
||||
# Validate that no cycles would be created
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Fields are incompatible')
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type')
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type')
|
||||
|
||||
return True
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
"""Determines whether or not a node exists in the graph."""
|
||||
@@ -733,7 +731,7 @@ class Graph(BaseModel):
|
||||
for sgn in (
|
||||
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
|
||||
):
|
||||
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
@@ -750,9 +748,7 @@ class Graph(BaseModel):
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(
|
||||
description="The id of the execution state", default_factory=uuid.uuid4
|
||||
)
|
||||
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
|
||||
|
||||
# TODO: Store a reference to the graph instead of the actual graph?
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
@@ -858,7 +854,8 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Returns true if the graph is complete"""
|
||||
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
|
||||
node_ids = set(self.graph.nx_graph_flat().nodes)
|
||||
return self.has_error() or all((k in self.executed for k in node_ids))
|
||||
|
||||
def has_error(self) -> bool:
|
||||
"""Returns true if the graph has any errors"""
|
||||
@@ -946,11 +943,11 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def _iterator_graph(self) -> nx.DiGraph:
|
||||
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
||||
g = self.graph.nx_graph()
|
||||
g = self.graph.nx_graph_flat()
|
||||
collectors = (
|
||||
n
|
||||
for n in self.graph.nodes
|
||||
if isinstance(self.graph.nodes[n], CollectInvocation)
|
||||
if isinstance(self.graph.get_node(n), CollectInvocation)
|
||||
)
|
||||
for c in collectors:
|
||||
g.remove_edges_from(list(g.in_edges(c)))
|
||||
@@ -962,7 +959,7 @@ class GraphExecutionState(BaseModel):
|
||||
iterators = [
|
||||
n
|
||||
for n in nx.ancestors(g, node_id)
|
||||
if isinstance(self.graph.nodes[n], IterateInvocation)
|
||||
if isinstance(self.graph.get_node(n), IterateInvocation)
|
||||
]
|
||||
return iterators
|
||||
|
||||
@@ -1098,7 +1095,9 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
if not self._is_edge_valid(edge):
|
||||
try:
|
||||
self.graph._validate_edge(edge)
|
||||
except InvalidEdgeError:
|
||||
return False
|
||||
|
||||
# Invalid if destination has already been prepared or executed
|
||||
@@ -1144,4 +1143,52 @@ class GraphExecutionState(BaseModel):
|
||||
self.graph.delete_edge(edge)
|
||||
|
||||
|
||||
class ExposedNodeInput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the input")
|
||||
field: str = Field(description="The field name of the input")
|
||||
alias: str = Field(description="The alias of the input")
|
||||
|
||||
|
||||
class ExposedNodeOutput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the output")
|
||||
field: str = Field(description="The field name of the output")
|
||||
alias: str = Field(description="The alias of the output")
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
description: str = Field(description="The description of the graph")
|
||||
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
|
||||
|
||||
@validator('exposed_inputs', 'exposed_outputs')
|
||||
def validate_exposed_aliases(cls, v):
|
||||
if len(v) != len(set(i.alias for i in v)):
|
||||
raise ValueError("Duplicate exposed alias")
|
||||
return v
|
||||
|
||||
@root_validator
|
||||
def validate_exposed_nodes(cls, values):
|
||||
graph = values['graph']
|
||||
|
||||
# Validate exposed inputs
|
||||
for exposed_input in values['exposed_inputs']:
|
||||
if not graph.has_node(exposed_input.node_path):
|
||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||
node = graph.get_node(exposed_input.node_path)
|
||||
if get_input_field(node, exposed_input.field) is None:
|
||||
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
|
||||
|
||||
# Validate exposed outputs
|
||||
for exposed_output in values['exposed_outputs']:
|
||||
if not graph.has_node(exposed_output.node_path):
|
||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||
node = graph.get_node(exposed_output.node_path)
|
||||
if get_output_field(node, exposed_output.field) is None:
|
||||
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
GraphInvocation.update_forward_refs()
|
||||
|
||||
@@ -2,17 +2,16 @@
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import json
|
||||
from glob import glob
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from PIL.Image import Image
|
||||
import PIL.Image as PILImage
|
||||
from pydantic import BaseModel, Json
|
||||
from pydantic import BaseModel
|
||||
from invokeai.app.api.models.images import ImageResponse
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
@@ -43,7 +42,7 @@ class ImageStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -101,8 +100,6 @@ class DiskImageStorage(ImageStorageBase):
|
||||
for path in page_of_image_paths:
|
||||
filename = os.path.basename(path)
|
||||
img = PILImage.open(path)
|
||||
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
|
||||
|
||||
page_of_images.append(
|
||||
ImageResponse(
|
||||
image_type=image_type.value,
|
||||
@@ -112,10 +109,9 @@ class DiskImageStorage(ImageStorageBase):
|
||||
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
|
||||
metadata=ImageMetadata(
|
||||
created=int(os.path.getctime(path)),
|
||||
timestamp=os.path.getctime(path),
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
invokeai=invokeai_metadata
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -154,11 +150,10 @@ class DiskImageStorage(ImageStorageBase):
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
return path
|
||||
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
|
||||
print(metadata)
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||
image_subpath = os.path.join(image_type, image_name)
|
||||
self.__pngWriter.save_image_and_prompt_to_png(
|
||||
image, "", image_subpath, metadata
|
||||
image, "", image_subpath, None
|
||||
) # TODO: just pass full path to png writer
|
||||
save_thumbnail(
|
||||
image=image,
|
||||
@@ -167,7 +162,6 @@ class DiskImageStorage(ImageStorageBase):
|
||||
)
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
self.__set_cache(image_path, image)
|
||||
return image_path
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
|
||||
@@ -1,30 +1,17 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import Queue
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# TODO: make this serializable
|
||||
class InvocationQueueItem:
|
||||
# session_id: str
|
||||
graph_execution_state_id: str
|
||||
invocation_id: str
|
||||
invoke_all: bool
|
||||
timestamp: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# session_id: str,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
invoke_all: bool = False,
|
||||
):
|
||||
# self.session_id = session_id
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.invocation_id = invocation_id
|
||||
self.invoke_all = invoke_all
|
||||
self.timestamp = time.time()
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
invoke_all: bool = Field(default=False)
|
||||
timestamp: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
class InvocationQueueABC(ABC):
|
||||
|
||||
@@ -19,6 +19,7 @@ class InvocationServices:
|
||||
restoration: RestorationServices
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: ItemStorageABC["LibraryGraph"]
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
processor: "InvocationProcessorABC"
|
||||
|
||||
@@ -29,6 +30,7 @@ class InvocationServices:
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageStorageBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_library: ItemStorageABC["LibraryGraph"],
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
@@ -38,6 +40,7 @@ class InvocationServices:
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
|
||||
@@ -43,14 +43,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
queue_item.invocation_id
|
||||
)
|
||||
|
||||
# get the source node to provide to cliepnts (the prepared node is not as useful)
|
||||
source_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||
|
||||
# Send starting event
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_dict=invocation.dict(),
|
||||
source_id=source_id
|
||||
invocation_id=invocation.id,
|
||||
)
|
||||
|
||||
# Invoke
|
||||
@@ -79,8 +75,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_dict=invocation.dict(),
|
||||
source_id=source_id,
|
||||
invocation_id=invocation.id,
|
||||
result=outputs.dict(),
|
||||
)
|
||||
|
||||
@@ -104,8 +99,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Send error event
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_dict=invocation.dict(),
|
||||
source_id=source_id,
|
||||
invocation_id=invocation.id,
|
||||
error=error,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,25 +1,23 @@
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
from typing import Generic, TypeVar, Union, get_args
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
|
||||
from .item_storage import ItemStorageABC, PaginatedResults
|
||||
|
||||
from sqlalchemy import create_engine, String, TEXT, Engine, select
|
||||
from sqlalchemy.orm import DeclarativeBase, mapped_column, Session
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
sqlite_memory = ":memory:"
|
||||
|
||||
|
||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
_filename: str
|
||||
_table_name: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_id_field: str
|
||||
_engine: Engine
|
||||
# _table: ??? # TODO: figure out how to type this
|
||||
_lock: Lock
|
||||
|
||||
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
|
||||
super().__init__()
|
||||
@@ -27,79 +25,72 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._filename = filename
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._lock = Lock()
|
||||
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._engine = create_engine(f"sqlite+pysqlite:///{self._filename}")
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
# dynamically create the ORM model class to avoid name collisions
|
||||
|
||||
# cannot access `self.__orig_class__` in `__init__` or `__new__` so
|
||||
# format the table name into the class name
|
||||
pascal_table_name = self._table_name.replace("_", " ").title()
|
||||
pascal_table_name = pascal_table_name.replace(" ", "")
|
||||
|
||||
table_dict = dict(
|
||||
__tablename__=self._table_name,
|
||||
id=mapped_column(String, primary_key=True),
|
||||
item=mapped_column(TEXT, nullable=False),
|
||||
)
|
||||
|
||||
self._table = type(pascal_table_name, (Base,), table_dict)
|
||||
|
||||
Base.metadata.create_all(self._engine)
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
|
||||
item TEXT,
|
||||
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
return parse_raw_as(item_type, item)
|
||||
|
||||
def set(self, item: T):
|
||||
session = Session(self._engine)
|
||||
|
||||
item_id = str(getattr(item, self._id_field))
|
||||
new_item = self._table(id=item_id, item=item.json())
|
||||
|
||||
session.merge(new_item)
|
||||
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||
(item.json(),),
|
||||
)
|
||||
self._conn.commit()
|
||||
self._on_changed(item)
|
||||
|
||||
def get(self, id: str) -> Union[T, None]:
|
||||
session = Session(self._engine)
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
|
||||
item = session.get(self._table, str(id))
|
||||
|
||||
session.close()
|
||||
|
||||
if not item:
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return self._parse_item(item.item)
|
||||
return self._parse_item(result[0])
|
||||
|
||||
def delete(self, id: str):
|
||||
session = Session(self._engine)
|
||||
|
||||
item = session.get(self._table, id)
|
||||
session.delete(item)
|
||||
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
self._conn.commit()
|
||||
self._on_deleted(id)
|
||||
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
session = Session(self._engine)
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
|
||||
(per_page, page * per_page),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
stmt = select(self._table.item).limit(per_page).offset(page * per_page)
|
||||
result = session.execute(stmt)
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
count = session.query(self._table.item).count()
|
||||
|
||||
session.commit()
|
||||
session.close()
|
||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
@@ -110,19 +101,20 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[T]:
|
||||
session = Session(self._engine)
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
|
||||
(f"%{query}%", per_page, page * per_page),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
stmt = (
|
||||
session.query(self._table)
|
||||
.where(self._table.item.like(f"%{query}%"))
|
||||
.limit(per_page)
|
||||
.offset(page * per_page)
|
||||
)
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
|
||||
result = session.execute(stmt)
|
||||
|
||||
items = list(map(lambda r: self._parse_item(r[0].item), result))
|
||||
count = session.query(self._table.item).count()
|
||||
self._cursor.execute(
|
||||
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
|
||||
(f"%{query}%",),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from re import S
|
||||
import torch
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
@@ -21,18 +20,12 @@ def fast_latents_step_callback(
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_id = graph_execution_state.prepared_source_mapping[id]
|
||||
|
||||
invocation = graph_execution_state.execution_graph.get_node(id)
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
invocation_dict=invocation.dict(),
|
||||
source_id=source_id,
|
||||
progress_image={"width": width, "height": height, "dataURL": dataURL},
|
||||
step=step,
|
||||
total_steps=steps,
|
||||
context.graph_execution_state_id,
|
||||
id,
|
||||
{"width": width, "height": height, "dataURL": dataURL},
|
||||
step,
|
||||
steps,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class PngWriter:
|
||||
info = PngImagePlugin.PngInfo()
|
||||
info.add_text("Dream", dream_prompt)
|
||||
if metadata:
|
||||
info.add_text("invokeai", json.dumps(metadata))
|
||||
info.add_text("sd-metadata", json.dumps(metadata))
|
||||
image.save(path, "PNG", pnginfo=info, compress_level=compress_level)
|
||||
return path
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
return self.concept_list
|
||||
return self.concept_list
|
||||
else:
|
||||
elif Globals.internet_available is True:
|
||||
try:
|
||||
models = self.hf_api.list_models(
|
||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
||||
@@ -73,6 +73,8 @@ class HuggingFaceConceptsLibrary(object):
|
||||
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||
)
|
||||
return self.concept_list
|
||||
else:
|
||||
return self.concept_list
|
||||
|
||||
def get_concept_model_path(self, concept_name: str) -> str:
|
||||
"""
|
||||
|
||||
@@ -6,5 +6,3 @@ stats.html
|
||||
index.html
|
||||
.yarn/
|
||||
*.scss
|
||||
src/services/api/
|
||||
src/services/fixtures/*
|
||||
|
||||
@@ -3,8 +3,4 @@ dist/
|
||||
node_modules/
|
||||
patches/
|
||||
stats.html
|
||||
index.html
|
||||
.yarn/
|
||||
*.scss
|
||||
src/services/api/
|
||||
src/services/fixtures/*
|
||||
|
||||
@@ -1,16 +1,10 @@
|
||||
# InvokeAI Web UI
|
||||
|
||||
- [InvokeAI Web UI](#invokeai-web-ui)
|
||||
- [Stack](#stack)
|
||||
- [Contributing](#contributing)
|
||||
- [Dev Environment](#dev-environment)
|
||||
- [Production builds](#production-builds)
|
||||
|
||||
The UI is a fairly straightforward Typescript React app. The only really fancy stuff is the Unified Canvas.
|
||||
|
||||
Code in `invokeai/frontend/web/` if you want to have a look.
|
||||
|
||||
## Stack
|
||||
## Details
|
||||
|
||||
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
|
||||
|
||||
@@ -38,7 +32,7 @@ Start everything in dev mode:
|
||||
|
||||
1. Start the dev server: `yarn dev`
|
||||
2. Start the InvokeAI UI per usual: `invokeai --web`
|
||||
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
|
||||
3. Point your browser to the dev server address e.g. `http://localhost:5173/`
|
||||
|
||||
### Production builds
|
||||
|
||||
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
File diff suppressed because one or more lines are too long
188
invokeai/frontend/web/dist/assets/App-af7ef809.js
vendored
Normal file
188
invokeai/frontend/web/dist/assets/App-af7ef809.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -1,4 +1,4 @@
|
||||
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
|
||||
import{j as y,cO as Ie,r as _,cP as bt,q as Lr,cQ as o,cR as b,cS as v,cT as S,cU as Vr,cV as ut,cW as vt,cN as ft,cX as mt,n as gt,cY as ht,E as pt}from"./index-e53e8108.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-5cde7d31.js";var Or=`
|
||||
:root {
|
||||
--chakra-vh: 100vh;
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-e53e8108.js"></script>
|
||||
<link rel="stylesheet" href="./assets/index-5483945c.css">
|
||||
</head>
|
||||
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/ar.json
vendored
1
invokeai/frontend/web/dist/locales/ar.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "داكن",
|
||||
"lightTheme": "فاتح",
|
||||
"greenTheme": "أخضر",
|
||||
"text2img": "نص إلى صورة",
|
||||
"img2img": "صورة إلى صورة",
|
||||
"unifiedCanvas": "لوحة موحدة",
|
||||
"nodes": "عقد",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/de.json
vendored
1
invokeai/frontend/web/dist/locales/de.json
vendored
@@ -7,7 +7,6 @@
|
||||
"darkTheme": "Dunkel",
|
||||
"lightTheme": "Hell",
|
||||
"greenTheme": "Grün",
|
||||
"text2img": "Text zu Bild",
|
||||
"img2img": "Bild zu Bild",
|
||||
"nodes": "Knoten",
|
||||
"langGerman": "Deutsch",
|
||||
|
||||
4
invokeai/frontend/web/dist/locales/en.json
vendored
4
invokeai/frontend/web/dist/locales/en.json
vendored
@@ -505,7 +505,9 @@
|
||||
"info": "Info",
|
||||
"deleteImage": "Delete Image",
|
||||
"initialImage": "Initial Image",
|
||||
"showOptionsPanel": "Show Options Panel"
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"hidePreview": "Hide Preview",
|
||||
"showPreview": "Show Preview"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Models",
|
||||
|
||||
12
invokeai/frontend/web/dist/locales/es.json
vendored
12
invokeai/frontend/web/dist/locales/es.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Oscuro",
|
||||
"lightTheme": "Claro",
|
||||
"greenTheme": "Verde",
|
||||
"text2img": "Texto a Imagen",
|
||||
"img2img": "Imagen a Imagen",
|
||||
"unifiedCanvas": "Lienzo Unificado",
|
||||
"nodes": "Nodos",
|
||||
@@ -70,7 +69,11 @@
|
||||
"langHebrew": "Hebreo",
|
||||
"pinOptionsPanel": "Pin del panel de opciones",
|
||||
"loading": "Cargando",
|
||||
"loadingInvokeAI": "Cargando invocar a la IA"
|
||||
"loadingInvokeAI": "Cargando invocar a la IA",
|
||||
"postprocessing": "Tratamiento posterior",
|
||||
"txt2img": "De texto a imagen",
|
||||
"accept": "Aceptar",
|
||||
"cancel": "Cancelar"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generaciones",
|
||||
@@ -404,7 +407,8 @@
|
||||
"none": "ninguno",
|
||||
"pickModelType": "Elige el tipo de modelo",
|
||||
"v2_768": "v2 (768px)",
|
||||
"addDifference": "Añadir una diferencia"
|
||||
"addDifference": "Añadir una diferencia",
|
||||
"scanForModels": "Buscar modelos"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Imágenes",
|
||||
@@ -574,7 +578,7 @@
|
||||
"autoSaveToGallery": "Guardar automáticamente en galería",
|
||||
"saveBoxRegionOnly": "Guardar solo región dentro de la caja",
|
||||
"limitStrokesToBox": "Limitar trazos a la caja",
|
||||
"showCanvasDebugInfo": "Mostrar información de depuración de lienzo",
|
||||
"showCanvasDebugInfo": "Mostrar la información adicional del lienzo",
|
||||
"clearCanvasHistory": "Limpiar historial de lienzo",
|
||||
"clearHistory": "Limpiar historial",
|
||||
"clearCanvasHistoryMessage": "Limpiar el historial de lienzo también restablece completamente el lienzo unificado. Esto incluye todo el historial de deshacer/rehacer, las imágenes en el área de preparación y la capa base del lienzo.",
|
||||
|
||||
25
invokeai/frontend/web/dist/locales/fr.json
vendored
25
invokeai/frontend/web/dist/locales/fr.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Sombre",
|
||||
"lightTheme": "Clair",
|
||||
"greenTheme": "Vert",
|
||||
"text2img": "Texte en image",
|
||||
"img2img": "Image en image",
|
||||
"unifiedCanvas": "Canvas unifié",
|
||||
"nodes": "Nœuds",
|
||||
@@ -47,7 +46,19 @@
|
||||
"statusLoadingModel": "Chargement du modèle",
|
||||
"statusModelChanged": "Modèle changé",
|
||||
"discordLabel": "Discord",
|
||||
"githubLabel": "Github"
|
||||
"githubLabel": "Github",
|
||||
"accept": "Accepter",
|
||||
"statusMergingModels": "Mélange des modèles",
|
||||
"loadingInvokeAI": "Chargement de Invoke AI",
|
||||
"cancel": "Annuler",
|
||||
"langEnglish": "Anglais",
|
||||
"statusConvertingModel": "Conversion du modèle",
|
||||
"statusModelConverted": "Modèle converti",
|
||||
"loading": "Chargement",
|
||||
"pinOptionsPanel": "Épingler la page d'options",
|
||||
"statusMergedModels": "Modèles mélangés",
|
||||
"txt2img": "Texte vers image",
|
||||
"postprocessing": "Post-Traitement"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Générations",
|
||||
@@ -518,5 +529,15 @@
|
||||
"betaDarkenOutside": "Assombrir à l'extérieur",
|
||||
"betaLimitToBox": "Limiter à la boîte",
|
||||
"betaPreserveMasked": "Conserver masqué"
|
||||
},
|
||||
"accessibility": {
|
||||
"uploadImage": "Charger une image",
|
||||
"reset": "Réinitialiser",
|
||||
"nextImage": "Image suivante",
|
||||
"previousImage": "Image précédente",
|
||||
"useThisParameter": "Utiliser ce paramètre",
|
||||
"zoomIn": "Zoom avant",
|
||||
"zoomOut": "Zoom arrière",
|
||||
"showOptionsPanel": "Montrer la page d'options"
|
||||
}
|
||||
}
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/he.json
vendored
1
invokeai/frontend/web/dist/locales/he.json
vendored
@@ -125,7 +125,6 @@
|
||||
"langSimplifiedChinese": "סינית",
|
||||
"langUkranian": "אוקראינית",
|
||||
"langSpanish": "ספרדית",
|
||||
"text2img": "טקסט לתמונה",
|
||||
"img2img": "תמונה לתמונה",
|
||||
"unifiedCanvas": "קנבס מאוחד",
|
||||
"nodes": "צמתים",
|
||||
|
||||
14
invokeai/frontend/web/dist/locales/it.json
vendored
14
invokeai/frontend/web/dist/locales/it.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Scuro",
|
||||
"lightTheme": "Chiaro",
|
||||
"greenTheme": "Verde",
|
||||
"text2img": "Testo a Immagine",
|
||||
"img2img": "Immagine a Immagine",
|
||||
"unifiedCanvas": "Tela unificata",
|
||||
"nodes": "Nodi",
|
||||
@@ -70,7 +69,11 @@
|
||||
"loading": "Caricamento in corso",
|
||||
"oceanTheme": "Oceano",
|
||||
"langHebrew": "Ebraico",
|
||||
"loadingInvokeAI": "Caricamento Invoke AI"
|
||||
"loadingInvokeAI": "Caricamento Invoke AI",
|
||||
"postprocessing": "Post Elaborazione",
|
||||
"txt2img": "Testo a Immagine",
|
||||
"accept": "Accetta",
|
||||
"cancel": "Annulla"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generazioni",
|
||||
@@ -404,7 +407,8 @@
|
||||
"v2_768": "v2 (768px)",
|
||||
"none": "niente",
|
||||
"addDifference": "Aggiungi differenza",
|
||||
"pickModelType": "Scegli il tipo di modello"
|
||||
"pickModelType": "Scegli il tipo di modello",
|
||||
"scanForModels": "Cerca modelli"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
@@ -574,7 +578,7 @@
|
||||
"autoSaveToGallery": "Salvataggio automatico nella Galleria",
|
||||
"saveBoxRegionOnly": "Salva solo l'area di selezione",
|
||||
"limitStrokesToBox": "Limita i tratti all'area di selezione",
|
||||
"showCanvasDebugInfo": "Mostra informazioni di debug della Tela",
|
||||
"showCanvasDebugInfo": "Mostra ulteriori informazioni sulla Tela",
|
||||
"clearCanvasHistory": "Cancella cronologia Tela",
|
||||
"clearHistory": "Cancella la cronologia",
|
||||
"clearCanvasHistoryMessage": "La cancellazione della cronologia della tela lascia intatta la tela corrente, ma cancella in modo irreversibile la cronologia degli annullamenti e dei ripristini.",
|
||||
@@ -612,7 +616,7 @@
|
||||
"copyMetadataJson": "Copia i metadati JSON",
|
||||
"exitViewer": "Esci dal visualizzatore",
|
||||
"zoomIn": "Zoom avanti",
|
||||
"zoomOut": "Zoom Indietro",
|
||||
"zoomOut": "Zoom indietro",
|
||||
"rotateCounterClockwise": "Ruotare in senso antiorario",
|
||||
"rotateClockwise": "Ruotare in senso orario",
|
||||
"flipHorizontally": "Capovolgi orizzontalmente",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/ko.json
vendored
1
invokeai/frontend/web/dist/locales/ko.json
vendored
@@ -11,7 +11,6 @@
|
||||
"langArabic": "العربية",
|
||||
"langEnglish": "English",
|
||||
"langDutch": "Nederlands",
|
||||
"text2img": "텍스트->이미지",
|
||||
"unifiedCanvas": "통합 캔버스",
|
||||
"langFrench": "Français",
|
||||
"langGerman": "Deutsch",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/nl.json
vendored
1
invokeai/frontend/web/dist/locales/nl.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Donker",
|
||||
"lightTheme": "Licht",
|
||||
"greenTheme": "Groen",
|
||||
"text2img": "Tekst naar afbeelding",
|
||||
"img2img": "Afbeelding naar afbeelding",
|
||||
"unifiedCanvas": "Centraal canvas",
|
||||
"nodes": "Knooppunten",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/pl.json
vendored
1
invokeai/frontend/web/dist/locales/pl.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Ciemny",
|
||||
"lightTheme": "Jasny",
|
||||
"greenTheme": "Zielony",
|
||||
"text2img": "Tekst na obraz",
|
||||
"img2img": "Obraz na obraz",
|
||||
"unifiedCanvas": "Tryb uniwersalny",
|
||||
"nodes": "Węzły",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/pt.json
vendored
1
invokeai/frontend/web/dist/locales/pt.json
vendored
@@ -20,7 +20,6 @@
|
||||
"langSpanish": "Espanhol",
|
||||
"langRussian": "Русский",
|
||||
"langUkranian": "Украї́нська",
|
||||
"text2img": "Texto para Imagem",
|
||||
"img2img": "Imagem para Imagem",
|
||||
"unifiedCanvas": "Tela Unificada",
|
||||
"nodes": "Nós",
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Noite",
|
||||
"lightTheme": "Dia",
|
||||
"greenTheme": "Verde",
|
||||
"text2img": "Texto Para Imagem",
|
||||
"img2img": "Imagem Para Imagem",
|
||||
"unifiedCanvas": "Tela Unificada",
|
||||
"nodes": "Nódulos",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/ru.json
vendored
1
invokeai/frontend/web/dist/locales/ru.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Темная",
|
||||
"lightTheme": "Светлая",
|
||||
"greenTheme": "Зеленая",
|
||||
"text2img": "Изображение из текста (text2img)",
|
||||
"img2img": "Изображение в изображение (img2img)",
|
||||
"unifiedCanvas": "Универсальный холст",
|
||||
"nodes": "Ноды",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/uk.json
vendored
1
invokeai/frontend/web/dist/locales/uk.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Темна",
|
||||
"lightTheme": "Світла",
|
||||
"greenTheme": "Зелена",
|
||||
"text2img": "Зображення із тексту (text2img)",
|
||||
"img2img": "Зображення із зображення (img2img)",
|
||||
"unifiedCanvas": "Універсальне полотно",
|
||||
"nodes": "Вузли",
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "暗色",
|
||||
"lightTheme": "亮色",
|
||||
"greenTheme": "绿色",
|
||||
"text2img": "文字到图像",
|
||||
"img2img": "图像到图像",
|
||||
"unifiedCanvas": "统一画布",
|
||||
"nodes": "节点",
|
||||
|
||||
@@ -33,7 +33,6 @@
|
||||
"langBrPortuguese": "巴西葡萄牙語",
|
||||
"langRussian": "俄語",
|
||||
"langSpanish": "西班牙語",
|
||||
"text2img": "文字到圖像",
|
||||
"unifiedCanvas": "統一畫布"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
# Generated axios API client
|
||||
|
||||
- [Generated axios API client](#generated-axios-api-client)
|
||||
- [Generation](#generation)
|
||||
- [Generate the API client from the nodes web server](#generate-the-api-client-from-the-nodes-web-server)
|
||||
- [Generate the API client from JSON](#generate-the-api-client-from-json)
|
||||
- [Getting the JSON from the nodes web server](#getting-the-json-from-the-nodes-web-server)
|
||||
- [Getting the JSON with a python script](#getting-the-json-with-a-python-script)
|
||||
- [Generate the API client](#generate-the-api-client)
|
||||
- [The generated client](#the-generated-client)
|
||||
- [API client customisation](#api-client-customisation)
|
||||
|
||||
This API client is generated by an [openapi code generator](https://github.com/ferdikoomen/openapi-typescript-codegen).
|
||||
|
||||
All files in `invokeai/frontend/web/src/services/api/` are made by the generator.
|
||||
|
||||
## Generation
|
||||
|
||||
The axios client may be generated by from the OpenAPI schema from the nodes web server, or from JSON.
|
||||
|
||||
### Generate the API client from the nodes web server
|
||||
|
||||
We need to start the nodes web server, which serves the OpenAPI schema to the generator.
|
||||
|
||||
1. Start the nodes web server.
|
||||
|
||||
```bash
|
||||
# from the repo root
|
||||
python scripts/invoke-new.py --web
|
||||
```
|
||||
|
||||
2. Generate the API client.
|
||||
|
||||
```bash
|
||||
# from invokeai/frontend/web/
|
||||
yarn api:web
|
||||
```
|
||||
|
||||
### Generate the API client from JSON
|
||||
|
||||
The JSON can be acquired from the nodes web server, or with a python script.
|
||||
|
||||
#### Getting the JSON from the nodes web server
|
||||
|
||||
Start the nodes web server as described above, then download the file.
|
||||
|
||||
```bash
|
||||
# from invokeai/frontend/web/
|
||||
curl http://localhost:9090/openapi.json -o openapi.json
|
||||
```
|
||||
|
||||
#### Getting the JSON with a python script
|
||||
|
||||
Run this python script from the repo root, so it can access the nodes server modules.
|
||||
|
||||
The script will output `openapi.json` in the repo root. Then we need to move it to `invokeai/frontend/web/`.
|
||||
|
||||
```bash
|
||||
# from the repo root
|
||||
python invokeai/app/util/generate_openapi_json.py
|
||||
mv invokeai/app/util/openapi.json invokeai/frontend/web/services/fixtures/
|
||||
```
|
||||
|
||||
#### Generate the API client
|
||||
|
||||
Now we can generate the API client from the JSON.
|
||||
|
||||
```bash
|
||||
# from invokeai/frontend/web/
|
||||
yarn api:file
|
||||
```
|
||||
|
||||
## The generated client
|
||||
|
||||
The client will be written to `invokeai/frontend/web/services/api/`:
|
||||
|
||||
- `axios` client
|
||||
- TS types
|
||||
- An easily parseable schema, which we can use to generate UI
|
||||
|
||||
## API client customisation
|
||||
|
||||
The generator has a default `request.ts` file that implements a base `axios` client. The generated client uses this base client.
|
||||
|
||||
One shortcoming of this is base client is it does not provide response headers unless the response body is empty. To fix this, we provide our own lightly-patched `request.ts`.
|
||||
|
||||
To access the headers, call `getHeaders(response)` on any response from the generated api client. This function is exported from `invokeai/frontend/web/src/services/util/getHeaders.ts`.
|
||||
@@ -1,21 +0,0 @@
|
||||
# Events
|
||||
|
||||
Events via `socket.io`
|
||||
|
||||
## `actions.ts`
|
||||
|
||||
Redux actions for all socket events. Payloads all include a timestamp, and optionally some other data.
|
||||
|
||||
Any reducer (or middleware) can respond to the actions.
|
||||
|
||||
## `middleware.ts`
|
||||
|
||||
Redux middleware for events.
|
||||
|
||||
Handles dispatching the event actions. Only put logic here if it can't really go anywhere else.
|
||||
|
||||
For example, on connect we want to load images to the gallery if it's not populated. This requires dispatching a thunk, so we need to directly dispatch this in the middleware.
|
||||
|
||||
## `types.ts`
|
||||
|
||||
Hand-written types for the socket events. Cannot generate these from the server, but fortunately they are few and simple.
|
||||
@@ -1,17 +0,0 @@
|
||||
# Node Editor Design
|
||||
|
||||
WIP
|
||||
|
||||
nodes
|
||||
|
||||
everything in `src/features/nodes/`
|
||||
|
||||
have a look at `state.nodes.invocation`
|
||||
|
||||
- on socket connect, if no schema saved, fetch `localhost:9090/openapi.json`, save JSON to `state.nodes.schema`
|
||||
- on fulfilled schema fetch, `parseSchema()` the schema. this outputs a `Record<string, Invocation>` which is saved to `state.nodes.invocations` - `Invocation` is like a template for the node
|
||||
- when you add a node, the the `Invocation` template is passed to `InvocationComponent.tsx` to build the UI component for that node
|
||||
- inputs/outputs have field types - and each field type gets an `FieldComponent` which includes a dispatcher to write state changes to redux `nodesSlice`
|
||||
- `reactflow` sends changes to nodes/edges to redux
|
||||
- to invoke, `buildNodesGraph()` state, then send this
|
||||
- changed onClick Invoke button actions to build the schema, then when schema builds it dispatches the actual network request to create the session - see `session.ts`
|
||||
@@ -1,29 +0,0 @@
|
||||
# Package Scripts
|
||||
|
||||
WIP walkthrough of `package.json` scripts.
|
||||
|
||||
## `theme` & `theme:watch`
|
||||
|
||||
These run the Chakra CLI to generate types for the theme, or watch for code change and re-generate the types.
|
||||
|
||||
The CLI essentially monkeypatches Chakra's files in `node_modules`.
|
||||
|
||||
## `postinstall`
|
||||
|
||||
The `postinstall` script patches a few packages and runs the Chakra CLI to generate types for the theme.
|
||||
|
||||
### Patch `@chakra-ui/cli`
|
||||
|
||||
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
|
||||
|
||||
### Patch `redux-persist`
|
||||
|
||||
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
|
||||
|
||||
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
|
||||
|
||||
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
|
||||
|
||||
### Patch `redux-deep-persist`
|
||||
|
||||
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.
|
||||
20
invokeai/frontend/web/index.d.ts
vendored
20
invokeai/frontend/web/index.d.ts
vendored
@@ -1,7 +1,6 @@
|
||||
import React, { PropsWithChildren } from 'react';
|
||||
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
|
||||
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
|
||||
export {};
|
||||
|
||||
@@ -65,24 +64,9 @@ declare module '@invoke-ai/invoke-ai-ui' {
|
||||
declare class SettingsModal extends React.Component<SettingsModalProps> {
|
||||
public constructor(props: SettingsModalProps);
|
||||
}
|
||||
|
||||
declare class StatusIndicator extends React.Component<StatusIndicatorProps> {
|
||||
public constructor(props: StatusIndicatorProps);
|
||||
}
|
||||
|
||||
declare class ModelSelect extends React.Component<ModelSelectProps> {
|
||||
public constructor(props: ModelSelectProps);
|
||||
}
|
||||
}
|
||||
|
||||
interface InvokeProps extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
disabledPanels?: string[];
|
||||
disabledTabs?: InvokeTabName[];
|
||||
token?: string;
|
||||
}
|
||||
|
||||
declare function Invoke(props: InvokeProps): JSX.Element;
|
||||
declare function Invoke(props: PropsWithChildren): JSX.Element;
|
||||
|
||||
export {
|
||||
ThemeChanger,
|
||||
@@ -90,7 +74,5 @@ export {
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
StatusIndicator,
|
||||
ModelSelect,
|
||||
};
|
||||
export = Invoke;
|
||||
|
||||
@@ -5,10 +5,7 @@
|
||||
"scripts": {
|
||||
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
|
||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
|
||||
"build": "yarn run lint && vite build",
|
||||
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
||||
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
||||
"preview": "vite preview",
|
||||
"lint:madge": "madge --circular src/main.tsx",
|
||||
"lint:eslint": "eslint --max-warnings=0 .",
|
||||
@@ -44,10 +41,9 @@
|
||||
"@chakra-ui/react": "^2.5.1",
|
||||
"@chakra-ui/styled-system": "^2.6.1",
|
||||
"@chakra-ui/theme-tools": "^2.0.16",
|
||||
"@dagrejs/graphlib": "^2.1.12",
|
||||
"@emotion/react": "^11.10.6",
|
||||
"@emotion/styled": "^11.10.6",
|
||||
"@reduxjs/toolkit": "^1.9.3",
|
||||
"@reduxjs/toolkit": "^1.9.2",
|
||||
"chakra-ui-contextmenu": "^1.0.5",
|
||||
"dateformat": "^5.0.3",
|
||||
"formik": "^2.2.9",
|
||||
@@ -71,9 +67,7 @@
|
||||
"react-redux": "^8.0.5",
|
||||
"react-transition-group": "^4.4.5",
|
||||
"react-zoom-pan-pinch": "^2.6.1",
|
||||
"reactflow": "^11.7.0",
|
||||
"redux-deep-persist": "^1.0.7",
|
||||
"redux-dynamic-middlewares": "^2.2.0",
|
||||
"redux-persist": "^6.0.0",
|
||||
"socket.io-client": "^4.6.0",
|
||||
"use-image": "^1.1.0",
|
||||
@@ -89,7 +83,6 @@
|
||||
"@typescript-eslint/eslint-plugin": "^5.52.0",
|
||||
"@typescript-eslint/parser": "^5.52.0",
|
||||
"@vitejs/plugin-react-swc": "^3.2.0",
|
||||
"axios": "^1.3.4",
|
||||
"babel-plugin-transform-imports": "^2.0.0",
|
||||
"concurrently": "^7.6.0",
|
||||
"eslint": "^8.34.0",
|
||||
@@ -97,17 +90,13 @@
|
||||
"eslint-plugin-prettier": "^4.2.1",
|
||||
"eslint-plugin-react": "^7.32.2",
|
||||
"eslint-plugin-react-hooks": "^4.6.0",
|
||||
"form-data": "^4.0.0",
|
||||
"husky": "^8.0.3",
|
||||
"lint-staged": "^13.1.2",
|
||||
"madge": "^6.0.0",
|
||||
"openapi-types": "^12.1.0",
|
||||
"openapi-typescript-codegen": "^0.23.0",
|
||||
"postinstall-postinstall": "^2.1.0",
|
||||
"prettier": "^2.8.4",
|
||||
"rollup-plugin-visualizer": "^5.9.0",
|
||||
"terser": "^5.16.4",
|
||||
"typescript": "4.9.5",
|
||||
"vite": "^4.1.2",
|
||||
"vite-plugin-eslint": "^1.8.1",
|
||||
"vite-tsconfig-paths": "^4.0.5",
|
||||
|
||||
@@ -505,7 +505,9 @@
|
||||
"info": "Info",
|
||||
"deleteImage": "Delete Image",
|
||||
"initialImage": "Initial Image",
|
||||
"showOptionsPanel": "Show Options Panel"
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"hidePreview": "Hide Preview",
|
||||
"showPreview": "Show Preview"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Models",
|
||||
@@ -522,10 +524,6 @@
|
||||
"resetComplete": "Web UI has been reset. Refresh the page to reload."
|
||||
},
|
||||
"toast": {
|
||||
"serverError": "Server Error",
|
||||
"disconnected": "Disconnected from Server",
|
||||
"connected": "Connected to Server",
|
||||
"canceled": "Processing Canceled",
|
||||
"tempFoldersEmptied": "Temp Folder Emptied",
|
||||
"uploadFailed": "Upload failed",
|
||||
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",
|
||||
|
||||
@@ -13,42 +13,16 @@ import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
|
||||
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
|
||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||
import { useAppDispatch, useAppSelector } from './storeHooks';
|
||||
import { useAppSelector } from './storeHooks';
|
||||
import { PropsWithChildren, useEffect } from 'react';
|
||||
import { setDisabledPanels, setDisabledTabs } from 'features/ui/store/uiSlice';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { shouldTransformUrlsChanged } from 'features/system/store/systemSlice';
|
||||
|
||||
keepGUIAlive();
|
||||
|
||||
interface Props extends PropsWithChildren {
|
||||
options: {
|
||||
disabledPanels: string[];
|
||||
disabledTabs: InvokeTabName[];
|
||||
shouldTransformUrls?: boolean;
|
||||
};
|
||||
}
|
||||
|
||||
const App = (props: Props) => {
|
||||
const App = (props: PropsWithChildren) => {
|
||||
useToastWatcher();
|
||||
|
||||
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
|
||||
const { setColorMode } = useColorMode();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(setDisabledPanels(props.options.disabledPanels));
|
||||
}, [dispatch, props.options.disabledPanels]);
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(setDisabledTabs(props.options.disabledTabs));
|
||||
}, [dispatch, props.options.disabledTabs]);
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(
|
||||
shouldTransformUrlsChanged(Boolean(props.options.shouldTransformUrls))
|
||||
);
|
||||
}, [dispatch, props.options.shouldTransformUrls]);
|
||||
|
||||
useEffect(() => {
|
||||
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
|
||||
|
||||
22
invokeai/frontend/web/src/app/invokeai.d.ts
vendored
22
invokeai/frontend/web/src/app/invokeai.d.ts
vendored
@@ -14,8 +14,6 @@
|
||||
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { IRect } from 'konva/lib/types';
|
||||
import { ImageMetadata, ImageType } from 'services/api';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
|
||||
/**
|
||||
* TODO:
|
||||
@@ -115,7 +113,7 @@ export declare type Metadata = SystemGenerationMetadata & {
|
||||
};
|
||||
|
||||
// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
|
||||
export declare type _Image = {
|
||||
export declare type Image = {
|
||||
uuid: string;
|
||||
url: string;
|
||||
thumbnail: string;
|
||||
@@ -126,23 +124,11 @@ export declare type _Image = {
|
||||
category: GalleryCategory;
|
||||
isBase64?: boolean;
|
||||
dreamPrompt?: 'string';
|
||||
name?: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* ResultImage
|
||||
*/
|
||||
export declare type Image = {
|
||||
name: string;
|
||||
type: ImageType;
|
||||
url: string;
|
||||
thumbnail: string;
|
||||
metadata: ImageMetadata;
|
||||
};
|
||||
|
||||
// GalleryImages is an array of Image.
|
||||
export declare type GalleryImages = {
|
||||
images: Array<_Image>;
|
||||
images: Array<Image>;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -289,7 +275,7 @@ export declare type SystemStatusResponse = SystemStatus;
|
||||
|
||||
export declare type SystemConfigResponse = SystemConfig;
|
||||
|
||||
export declare type ImageResultResponse = Omit<_Image, 'uuid'> & {
|
||||
export declare type ImageResultResponse = Omit<Image, 'uuid'> & {
|
||||
boundingBox?: IRect;
|
||||
generationMode: InvokeTabName;
|
||||
};
|
||||
@@ -310,7 +296,7 @@ export declare type ErrorResponse = {
|
||||
};
|
||||
|
||||
export declare type GalleryImagesResponse = {
|
||||
images: Array<Omit<_Image, 'uuid'>>;
|
||||
images: Array<Omit<Image, 'uuid'>>;
|
||||
areMoreImagesAvailable: boolean;
|
||||
category: GalleryCategory;
|
||||
};
|
||||
|
||||
@@ -13,13 +13,9 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
export const generateImage = createAction<InvokeTabName>(
|
||||
'socketio/generateImage'
|
||||
);
|
||||
export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
|
||||
export const runFacetool = createAction<InvokeAI._Image>(
|
||||
'socketio/runFacetool'
|
||||
);
|
||||
export const deleteImage = createAction<InvokeAI._Image>(
|
||||
'socketio/deleteImage'
|
||||
);
|
||||
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
|
||||
export const runFacetool = createAction<InvokeAI.Image>('socketio/runFacetool');
|
||||
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
|
||||
export const requestImages = createAction<GalleryCategory>(
|
||||
'socketio/requestImages'
|
||||
);
|
||||
|
||||
@@ -91,7 +91,7 @@ const makeSocketIOEmitters = (
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
|
||||
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
const {
|
||||
@@ -119,7 +119,7 @@ const makeSocketIOEmitters = (
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
|
||||
emitRunFacetool: (imageToProcess: InvokeAI.Image) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
const {
|
||||
@@ -150,7 +150,7 @@ const makeSocketIOEmitters = (
|
||||
})
|
||||
);
|
||||
},
|
||||
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
|
||||
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
|
||||
const { url, uuid, category, thumbnail } = imageToDelete;
|
||||
dispatch(removeImage(imageToDelete));
|
||||
socketio.emit('deleteImage', url, thumbnail, uuid, category);
|
||||
|
||||
@@ -34,9 +34,8 @@ import type { RootState } from 'app/store';
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
clearInitialImage,
|
||||
initialImageSelected,
|
||||
setInfillMethod,
|
||||
// setInitialImage,
|
||||
setInitialImage,
|
||||
setMaskPath,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { tabMap } from 'features/ui/store/tabMap';
|
||||
@@ -147,8 +146,7 @@ const makeSocketIOListeners = (
|
||||
const activeTabName = tabMap[activeTab];
|
||||
switch (activeTabName) {
|
||||
case 'img2img': {
|
||||
dispatch(initialImageSelected(newImage.uuid));
|
||||
// dispatch(setInitialImage(newImage));
|
||||
dispatch(setInitialImage(newImage));
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -264,7 +262,7 @@ const makeSocketIOListeners = (
|
||||
*/
|
||||
|
||||
// Generate a UUID for each image
|
||||
const preparedImages = images.map((image): InvokeAI._Image => {
|
||||
const preparedImages = images.map((image): InvokeAI.Image => {
|
||||
return {
|
||||
uuid: uuidv4(),
|
||||
...image,
|
||||
@@ -336,7 +334,7 @@ const makeSocketIOListeners = (
|
||||
|
||||
if (
|
||||
initialImage === url ||
|
||||
(initialImage as InvokeAI._Image)?.url === url
|
||||
(initialImage as InvokeAI.Image)?.url === url
|
||||
) {
|
||||
dispatch(clearInitialImage());
|
||||
}
|
||||
|
||||
@@ -29,8 +29,6 @@ export const socketioMiddleware = () => {
|
||||
path: `${window.location.pathname}socket.io`,
|
||||
});
|
||||
|
||||
socketio.disconnect();
|
||||
|
||||
let areListenersSet = false;
|
||||
|
||||
const middleware: Middleware = (store) => (next) => (action) => {
|
||||
|
||||
@@ -2,35 +2,18 @@ import { combineReducers, configureStore } from '@reduxjs/toolkit';
|
||||
|
||||
import { persistReducer } from 'redux-persist';
|
||||
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
|
||||
import { getPersistConfig } from 'redux-deep-persist';
|
||||
|
||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||
import galleryReducer, {
|
||||
GalleryState,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import resultsReducer, {
|
||||
resultsAdapter,
|
||||
ResultsState,
|
||||
} from 'features/gallery/store/resultsSlice';
|
||||
import uploadsReducer from 'features/gallery/store/uploadsSlice';
|
||||
import lightboxReducer, {
|
||||
LightboxState,
|
||||
} from 'features/lightbox/store/lightboxSlice';
|
||||
import generationReducer, {
|
||||
GenerationState,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import postprocessingReducer, {
|
||||
PostprocessingState,
|
||||
} from 'features/parameters/store/postprocessingSlice';
|
||||
import systemReducer, { SystemState } from 'features/system/store/systemSlice';
|
||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||
import generationReducer from 'features/parameters/store/generationSlice';
|
||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
import modelsReducer from 'features/system/store/modelSlice';
|
||||
import nodesReducer, { NodesState } from 'features/nodes/store/nodesSlice';
|
||||
|
||||
import { socketioMiddleware } from './socketio/middleware';
|
||||
import { socketMiddleware } from 'services/events/middleware';
|
||||
import { CanvasState } from 'features/canvas/store/canvasTypes';
|
||||
|
||||
/**
|
||||
* redux-persist provides an easy and reliable way to persist state across reloads.
|
||||
@@ -46,21 +29,13 @@ import { CanvasState } from 'features/canvas/store/canvasTypes';
|
||||
* The necesssary nested persistors with blacklists are configured below.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Canvas slice persist blacklist
|
||||
*/
|
||||
const canvasBlacklist: (keyof CanvasState)[] = [
|
||||
const canvasBlacklist = [
|
||||
'cursorPosition',
|
||||
'isCanvasInitialized',
|
||||
'doesCanvasNeedScaling',
|
||||
];
|
||||
].map((blacklistItem) => `canvas.${blacklistItem}`);
|
||||
|
||||
canvasBlacklist.map((blacklistItem) => `canvas.${blacklistItem}`);
|
||||
|
||||
/**
|
||||
* System slice persist blacklist
|
||||
*/
|
||||
const systemBlacklist: (keyof SystemState)[] = [
|
||||
const systemBlacklist = [
|
||||
'currentIteration',
|
||||
'currentStatus',
|
||||
'currentStep',
|
||||
@@ -73,101 +48,30 @@ const systemBlacklist: (keyof SystemState)[] = [
|
||||
'totalIterations',
|
||||
'totalSteps',
|
||||
'openModel',
|
||||
'isCancelScheduled',
|
||||
'sessionId',
|
||||
'progressImage',
|
||||
];
|
||||
'cancelOptions.cancelAfter',
|
||||
].map((blacklistItem) => `system.${blacklistItem}`);
|
||||
|
||||
systemBlacklist.map((blacklistItem) => `system.${blacklistItem}`);
|
||||
|
||||
/**
|
||||
* Gallery slice persist blacklist
|
||||
*/
|
||||
const galleryBlacklist: (keyof GalleryState)[] = [
|
||||
const galleryBlacklist = [
|
||||
'categories',
|
||||
'currentCategory',
|
||||
'currentImage',
|
||||
'currentImageUuid',
|
||||
'shouldAutoSwitchToNewImages',
|
||||
'intermediateImage',
|
||||
];
|
||||
].map((blacklistItem) => `gallery.${blacklistItem}`);
|
||||
|
||||
galleryBlacklist.map((blacklistItem) => `gallery.${blacklistItem}`);
|
||||
|
||||
/**
|
||||
* Lightbox slice persist blacklist
|
||||
*/
|
||||
const lightboxBlacklist: (keyof LightboxState)[] = ['isLightboxOpen'];
|
||||
|
||||
lightboxBlacklist.map((blacklistItem) => `lightbox.${blacklistItem}`);
|
||||
|
||||
/**
|
||||
* Nodes slice persist blacklist
|
||||
*/
|
||||
const nodesBlacklist: (keyof NodesState)[] = ['schema', 'invocations'];
|
||||
|
||||
nodesBlacklist.map((blacklistItem) => `nodes.${blacklistItem}`);
|
||||
|
||||
/**
|
||||
* Generation slice persist blacklist
|
||||
*/
|
||||
const generationBlacklist: (keyof GenerationState)[] = [];
|
||||
|
||||
generationBlacklist.map((blacklistItem) => `generation.${blacklistItem}`);
|
||||
|
||||
/**
|
||||
* Postprocessing slice persist blacklist
|
||||
*/
|
||||
const postprocessingBlacklist: (keyof PostprocessingState)[] = [];
|
||||
|
||||
postprocessingBlacklist.map(
|
||||
(blacklistItem) => `postprocessing.${blacklistItem}`
|
||||
const lightboxBlacklist = ['isLightboxOpen'].map(
|
||||
(blacklistItem) => `lightbox.${blacklistItem}`
|
||||
);
|
||||
|
||||
/**
|
||||
* Results slice persist blacklist
|
||||
*
|
||||
* Currently blacklisting results slice entirely, see persist config below
|
||||
*/
|
||||
const resultsBlacklist: (keyof ResultsState)[] = [];
|
||||
|
||||
resultsBlacklist.map((blacklistItem) => `results.${blacklistItem}`);
|
||||
|
||||
/**
|
||||
* Uploads slice persist blacklist
|
||||
*
|
||||
* Currently blacklisting uploads slice entirely, see persist config below
|
||||
*/
|
||||
const uploadsBlacklist: (keyof NodesState)[] = [];
|
||||
|
||||
uploadsBlacklist.map((blacklistItem) => `uploads.${blacklistItem}`);
|
||||
|
||||
/**
|
||||
* Models slice persist blacklist
|
||||
*/
|
||||
const modelsBlacklist: (keyof NodesState)[] = [];
|
||||
|
||||
modelsBlacklist.map((blacklistItem) => `models.${blacklistItem}`);
|
||||
|
||||
/**
|
||||
* UI slice persist blacklist
|
||||
*/
|
||||
const uiBlacklist: (keyof NodesState)[] = [];
|
||||
|
||||
uiBlacklist.map((blacklistItem) => `ui.${blacklistItem}`);
|
||||
|
||||
const rootReducer = combineReducers({
|
||||
canvas: canvasReducer,
|
||||
gallery: galleryReducer,
|
||||
generation: generationReducer,
|
||||
lightbox: lightboxReducer,
|
||||
models: modelsReducer,
|
||||
nodes: nodesReducer,
|
||||
postprocessing: postprocessingReducer,
|
||||
results: resultsReducer,
|
||||
gallery: galleryReducer,
|
||||
system: systemReducer,
|
||||
canvas: canvasReducer,
|
||||
ui: uiReducer,
|
||||
uploads: uploadsReducer,
|
||||
lightbox: lightboxReducer,
|
||||
});
|
||||
|
||||
const rootPersistConfig = getPersistConfig({
|
||||
@@ -176,40 +80,23 @@ const rootPersistConfig = getPersistConfig({
|
||||
rootReducer,
|
||||
blacklist: [
|
||||
...canvasBlacklist,
|
||||
...galleryBlacklist,
|
||||
...generationBlacklist,
|
||||
...lightboxBlacklist,
|
||||
...modelsBlacklist,
|
||||
...nodesBlacklist,
|
||||
...postprocessingBlacklist,
|
||||
// ...resultsBlacklist,
|
||||
'results',
|
||||
...systemBlacklist,
|
||||
...uiBlacklist,
|
||||
// ...uploadsBlacklist,
|
||||
'uploads',
|
||||
...galleryBlacklist,
|
||||
...lightboxBlacklist,
|
||||
],
|
||||
debounce: 300,
|
||||
});
|
||||
|
||||
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
|
||||
|
||||
// TODO: rip the old middleware out when nodes is complete
|
||||
export function buildMiddleware() {
|
||||
if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
|
||||
return socketMiddleware();
|
||||
} else {
|
||||
return socketioMiddleware();
|
||||
}
|
||||
}
|
||||
|
||||
// Continue with store setup
|
||||
export const store = configureStore({
|
||||
reducer: persistedReducer,
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware({
|
||||
immutableCheck: false,
|
||||
serializableCheck: false,
|
||||
}).concat(dynamicMiddlewares),
|
||||
}).concat(socketioMiddleware()),
|
||||
devTools: {
|
||||
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
||||
actionsDenylist: [
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
import { createAsyncThunk } from '@reduxjs/toolkit';
|
||||
import { AppDispatch, RootState } from './store';
|
||||
|
||||
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
|
||||
export const createAppAsyncThunk = createAsyncThunk.withTypes<{
|
||||
state: RootState;
|
||||
dispatch: AppDispatch;
|
||||
}>();
|
||||
@@ -2,6 +2,7 @@ import { Box, useToast } from '@chakra-ui/react';
|
||||
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import useImageUploader from 'common/hooks/useImageUploader';
|
||||
import { uploadImage } from 'features/gallery/store/thunks/uploadImage';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { ResourceKey } from 'i18next';
|
||||
import {
|
||||
@@ -14,7 +15,6 @@ import {
|
||||
} from 'react';
|
||||
import { FileRejection, useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import ImageUploadOverlay from './ImageUploadOverlay';
|
||||
|
||||
type ImageUploaderProps = {
|
||||
@@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
|
||||
const fileAcceptedCallback = useCallback(
|
||||
async (file: File) => {
|
||||
dispatch(imageUploaded({ formData: { file } }));
|
||||
dispatch(uploadImage({ imageFile: file }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
@@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(imageUploaded({ formData: { file } }));
|
||||
dispatch(uploadImage({ imageFile: file }));
|
||||
};
|
||||
document.addEventListener('paste', pasteImageListener);
|
||||
return () => {
|
||||
|
||||
@@ -1,160 +1,27 @@
|
||||
// import WorkInProgress from './WorkInProgress';
|
||||
// import ReactFlow, {
|
||||
// applyEdgeChanges,
|
||||
// applyNodeChanges,
|
||||
// Background,
|
||||
// Controls,
|
||||
// Edge,
|
||||
// Handle,
|
||||
// Node,
|
||||
// NodeTypes,
|
||||
// OnEdgesChange,
|
||||
// OnNodesChange,
|
||||
// Position,
|
||||
// } from 'reactflow';
|
||||
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import WorkInProgress from './WorkInProgress';
|
||||
|
||||
// import 'reactflow/dist/style.css';
|
||||
// import {
|
||||
// Fragment,
|
||||
// FunctionComponent,
|
||||
// ReactNode,
|
||||
// useCallback,
|
||||
// useMemo,
|
||||
// useState,
|
||||
// } from 'react';
|
||||
// import { OpenAPIV3 } from 'openapi-types';
|
||||
// import { filter, map, reduce } from 'lodash';
|
||||
// import {
|
||||
// Box,
|
||||
// Flex,
|
||||
// FormControl,
|
||||
// FormLabel,
|
||||
// Input,
|
||||
// Select,
|
||||
// Switch,
|
||||
// Text,
|
||||
// NumberInput,
|
||||
// NumberInputField,
|
||||
// NumberInputStepper,
|
||||
// NumberIncrementStepper,
|
||||
// NumberDecrementStepper,
|
||||
// Tooltip,
|
||||
// chakra,
|
||||
// Badge,
|
||||
// Heading,
|
||||
// VStack,
|
||||
// HStack,
|
||||
// Menu,
|
||||
// MenuButton,
|
||||
// MenuList,
|
||||
// MenuItem,
|
||||
// MenuItemOption,
|
||||
// MenuGroup,
|
||||
// MenuOptionGroup,
|
||||
// MenuDivider,
|
||||
// IconButton,
|
||||
// } from '@chakra-ui/react';
|
||||
// import { FaPlus } from 'react-icons/fa';
|
||||
// import {
|
||||
// FIELD_NAMES as FIELD_NAMES,
|
||||
// FIELDS,
|
||||
// INVOCATION_NAMES as INVOCATION_NAMES,
|
||||
// INVOCATIONS,
|
||||
// } from 'features/nodeEditor/constants';
|
||||
|
||||
// console.log('invocations', INVOCATIONS);
|
||||
|
||||
// const nodeTypes = reduce(
|
||||
// INVOCATIONS,
|
||||
// (acc, val, key) => {
|
||||
// acc[key] = val.component;
|
||||
// return acc;
|
||||
// },
|
||||
// {} as NodeTypes
|
||||
// );
|
||||
|
||||
// console.log('nodeTypes', nodeTypes);
|
||||
|
||||
// // make initial nodes one of every node for now
|
||||
// let n = 0;
|
||||
// const initialNodes = map(INVOCATIONS, (i) => ({
|
||||
// id: i.type,
|
||||
// type: i.title,
|
||||
// position: { x: (n += 20), y: (n += 20) },
|
||||
// data: {},
|
||||
// }));
|
||||
|
||||
// console.log('initialNodes', initialNodes);
|
||||
|
||||
// export default function NodesWIP() {
|
||||
// const [nodes, setNodes] = useState<Node[]>([]);
|
||||
// const [edges, setEdges] = useState<Edge[]>([]);
|
||||
|
||||
// const onNodesChange: OnNodesChange = useCallback(
|
||||
// (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
|
||||
// []
|
||||
// );
|
||||
|
||||
// const onEdgesChange: OnEdgesChange = useCallback(
|
||||
// (changes) => setEdges((eds: Edge[]) => applyEdgeChanges(changes, eds)),
|
||||
// []
|
||||
// );
|
||||
|
||||
// return (
|
||||
// <Box
|
||||
// sx={{
|
||||
// position: 'relative',
|
||||
// width: 'full',
|
||||
// height: 'full',
|
||||
// borderRadius: 'md',
|
||||
// }}
|
||||
// >
|
||||
// <ReactFlow
|
||||
// nodeTypes={nodeTypes}
|
||||
// nodes={nodes}
|
||||
// edges={edges}
|
||||
// onNodesChange={onNodesChange}
|
||||
// onEdgesChange={onEdgesChange}
|
||||
// >
|
||||
// <Background />
|
||||
// <Controls />
|
||||
// </ReactFlow>
|
||||
// <HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
|
||||
// {FIELD_NAMES.map((field) => (
|
||||
// <Badge
|
||||
// key={field}
|
||||
// colorScheme={FIELDS[field].color}
|
||||
// sx={{ userSelect: 'none' }}
|
||||
// >
|
||||
// {field}
|
||||
// </Badge>
|
||||
// ))}
|
||||
// </HStack>
|
||||
// <Menu>
|
||||
// <MenuButton
|
||||
// as={IconButton}
|
||||
// aria-label="Options"
|
||||
// icon={<FaPlus />}
|
||||
// sx={{ position: 'absolute', top: 2, left: 2 }}
|
||||
// />
|
||||
// <MenuList>
|
||||
// {INVOCATION_NAMES.map((name) => {
|
||||
// const invocation = INVOCATIONS[name];
|
||||
// return (
|
||||
// <Tooltip
|
||||
// key={name}
|
||||
// label={invocation.description}
|
||||
// placement="end"
|
||||
// hasArrow
|
||||
// >
|
||||
// <MenuItem>{invocation.title}</MenuItem>
|
||||
// </Tooltip>
|
||||
// );
|
||||
// })}
|
||||
// </MenuList>
|
||||
// </Menu>
|
||||
// </Box>
|
||||
// );
|
||||
// }
|
||||
|
||||
export default {};
|
||||
export default function NodesWIP() {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<WorkInProgress>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
w: '100%',
|
||||
h: '100%',
|
||||
gap: 4,
|
||||
textAlign: 'center',
|
||||
}}
|
||||
>
|
||||
<Heading>{t('common.nodes')}</Heading>
|
||||
<VStack maxW="50rem" gap={4}>
|
||||
<Text>{t('common.nodesDesc')}</Text>
|
||||
</VStack>
|
||||
</Flex>
|
||||
</WorkInProgress>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -14,8 +14,6 @@ const WorkInProgress = (props: WorkInProgressProps) => {
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
bg: 'base.850',
|
||||
borderRadius: 'base',
|
||||
position: 'relative',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
import { RootState } from 'app/store';
|
||||
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
|
||||
import { find } from 'lodash';
|
||||
import {
|
||||
Graph,
|
||||
ImageToImageInvocation,
|
||||
TextToImageInvocation,
|
||||
} from 'services/api';
|
||||
import { buildHiResNode, buildImg2ImgNode } from './nodes/image2Image';
|
||||
import { buildIteration } from './nodes/iteration';
|
||||
import { buildTxt2ImgNode } from './nodes/text2Image';
|
||||
|
||||
function mapTabToFunction(activeTabName: InvokeTabName) {
|
||||
switch (activeTabName) {
|
||||
case 'txt2img':
|
||||
return buildTxt2ImgNode;
|
||||
|
||||
case 'img2img':
|
||||
return buildImg2ImgNode;
|
||||
|
||||
default:
|
||||
return buildTxt2ImgNode;
|
||||
}
|
||||
}
|
||||
|
||||
const buildBaseNode = (
|
||||
state: RootState
|
||||
): Record<string, TextToImageInvocation | ImageToImageInvocation> => {
|
||||
const { activeTab } = state.ui;
|
||||
const activeTabName = tabMap[activeTab];
|
||||
|
||||
return mapTabToFunction(activeTabName)(state);
|
||||
};
|
||||
|
||||
type BuildGraphOutput = {
|
||||
graph: Graph;
|
||||
nodeIdsToSubscribe: string[];
|
||||
};
|
||||
|
||||
export const buildGraph = (state: RootState): BuildGraphOutput => {
|
||||
const { generation, postprocessing } = state;
|
||||
const { iterations } = generation;
|
||||
const { hiresFix, hiresStrength } = postprocessing;
|
||||
|
||||
const baseNode = buildBaseNode(state);
|
||||
|
||||
let graph: Graph = { nodes: baseNode };
|
||||
const nodeIdsToSubscribe: string[] = [];
|
||||
|
||||
if (iterations > 1) {
|
||||
graph = buildIteration({ graph, iterations });
|
||||
}
|
||||
|
||||
if (hiresFix) {
|
||||
const { node, edge } = buildHiResNode(
|
||||
baseNode as Record<string, TextToImageInvocation>,
|
||||
hiresStrength
|
||||
);
|
||||
graph = {
|
||||
nodes: {
|
||||
...graph.nodes,
|
||||
...node,
|
||||
},
|
||||
edges: [...(graph.edges || []), edge],
|
||||
};
|
||||
nodeIdsToSubscribe.push(Object.keys(node)[0]);
|
||||
}
|
||||
|
||||
console.log('buildGraph: ', graph);
|
||||
|
||||
return { graph, nodeIdsToSubscribe };
|
||||
};
|
||||
@@ -1,6 +0,0 @@
|
||||
import dateFormat from 'dateformat';
|
||||
|
||||
/**
|
||||
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
|
||||
*/
|
||||
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');
|
||||
@@ -1,28 +0,0 @@
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { OpenAPI } from 'services/api';
|
||||
|
||||
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
|
||||
if (OpenAPI.BASE && shouldTransformUrls) {
|
||||
return [OpenAPI.BASE, url].join('/');
|
||||
}
|
||||
|
||||
return url;
|
||||
};
|
||||
|
||||
export const useGetUrl = () => {
|
||||
const shouldTransformUrls = useAppSelector(
|
||||
(state: RootState) => state.system.shouldTransformUrls
|
||||
);
|
||||
|
||||
return {
|
||||
shouldTransformUrls,
|
||||
getUrl: (url: string) => {
|
||||
if (OpenAPI.BASE && shouldTransformUrls) {
|
||||
return [OpenAPI.BASE, url].join('/');
|
||||
}
|
||||
|
||||
return url;
|
||||
},
|
||||
};
|
||||
};
|
||||
@@ -1,98 +0,0 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
Edge,
|
||||
ImageToImageInvocation,
|
||||
TextToImageInvocation,
|
||||
} from 'services/api';
|
||||
import { _Image } from 'app/invokeai';
|
||||
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
|
||||
|
||||
export const buildImg2ImgNode = (
|
||||
state: RootState
|
||||
): Record<string, ImageToImageInvocation> => {
|
||||
const nodeId = uuidv4();
|
||||
const { generation, system, models } = state;
|
||||
|
||||
const { shouldDisplayInProgressType } = system;
|
||||
const { currentModel: model } = models;
|
||||
|
||||
const {
|
||||
prompt,
|
||||
seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfgScale,
|
||||
sampler,
|
||||
seamless,
|
||||
img2imgStrength: strength,
|
||||
shouldFitToWidthHeight: fit,
|
||||
shouldRandomizeSeed,
|
||||
} = generation;
|
||||
|
||||
const initialImage = initialImageSelector(state);
|
||||
|
||||
if (!initialImage) {
|
||||
// TODO: handle this
|
||||
throw 'no initial image';
|
||||
}
|
||||
|
||||
return {
|
||||
[nodeId]: {
|
||||
id: nodeId,
|
||||
type: 'img2img',
|
||||
prompt,
|
||||
seed: shouldRandomizeSeed ? -1 : seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfg_scale: cfgScale,
|
||||
scheduler: sampler as ImageToImageInvocation['scheduler'],
|
||||
seamless,
|
||||
model,
|
||||
progress_images: shouldDisplayInProgressType === 'full-res',
|
||||
image: {
|
||||
image_name: initialImage.name,
|
||||
image_type: initialImage.type,
|
||||
},
|
||||
strength,
|
||||
fit,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
type hiresReturnType = {
|
||||
node: Record<string, ImageToImageInvocation>;
|
||||
edge: Edge;
|
||||
};
|
||||
|
||||
export const buildHiResNode = (
|
||||
baseNode: Record<string, TextToImageInvocation>,
|
||||
strength?: number
|
||||
): hiresReturnType => {
|
||||
const nodeId = uuidv4();
|
||||
const baseNodeId = Object.keys(baseNode)[0];
|
||||
const baseNodeValues = Object.values(baseNode)[0];
|
||||
|
||||
return {
|
||||
node: {
|
||||
[nodeId]: {
|
||||
...baseNodeValues,
|
||||
id: nodeId,
|
||||
type: 'img2img',
|
||||
strength,
|
||||
},
|
||||
},
|
||||
edge: {
|
||||
source: {
|
||||
field: 'image',
|
||||
node_id: baseNodeId,
|
||||
},
|
||||
destination: {
|
||||
field: 'image',
|
||||
node_id: nodeId,
|
||||
},
|
||||
},
|
||||
};
|
||||
};
|
||||
@@ -1,81 +0,0 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import {
|
||||
Edge,
|
||||
Graph,
|
||||
ImageToImageInvocation,
|
||||
IterateInvocation,
|
||||
RangeInvocation,
|
||||
TextToImageInvocation,
|
||||
} from 'services/api';
|
||||
import { buildImg2ImgNode } from './image2Image';
|
||||
|
||||
type BuildIteration = {
|
||||
graph: Graph;
|
||||
iterations: number;
|
||||
};
|
||||
|
||||
const buildRangeNode = (
|
||||
iterations: number
|
||||
): Record<string, RangeInvocation> => {
|
||||
const nodeId = uuidv4();
|
||||
return {
|
||||
[nodeId]: {
|
||||
id: nodeId,
|
||||
type: 'range',
|
||||
start: 0,
|
||||
stop: iterations,
|
||||
step: 1,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
const buildIterateNode = (): Record<string, IterateInvocation> => {
|
||||
const nodeId = uuidv4();
|
||||
return {
|
||||
[nodeId]: {
|
||||
id: nodeId,
|
||||
type: 'iterate',
|
||||
collection: [],
|
||||
index: 0,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
export const buildIteration = ({
|
||||
graph,
|
||||
iterations,
|
||||
}: BuildIteration): Graph => {
|
||||
const rangeNode = buildRangeNode(iterations);
|
||||
const iterateNode = buildIterateNode();
|
||||
const baseNode: Graph['nodes'] = graph.nodes;
|
||||
const edges: Edge[] = [
|
||||
{
|
||||
source: {
|
||||
field: 'collection',
|
||||
node_id: Object.keys(rangeNode)[0],
|
||||
},
|
||||
destination: {
|
||||
field: 'collection',
|
||||
node_id: Object.keys(iterateNode)[0],
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
field: 'item',
|
||||
node_id: Object.keys(iterateNode)[0],
|
||||
},
|
||||
destination: {
|
||||
field: 'seed',
|
||||
node_id: Object.keys(baseNode!)[0],
|
||||
},
|
||||
},
|
||||
];
|
||||
return {
|
||||
nodes: {
|
||||
...rangeNode,
|
||||
...iterateNode,
|
||||
...graph.nodes,
|
||||
},
|
||||
edges,
|
||||
};
|
||||
};
|
||||
@@ -1,43 +0,0 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { RootState } from 'app/store';
|
||||
import { TextToImageInvocation } from 'services/api';
|
||||
|
||||
export const buildTxt2ImgNode = (
|
||||
state: RootState
|
||||
): Record<string, TextToImageInvocation> => {
|
||||
const nodeId = uuidv4();
|
||||
const { generation, system, models } = state;
|
||||
|
||||
const { shouldDisplayInProgressType } = system;
|
||||
const { currentModel: model } = models;
|
||||
|
||||
const {
|
||||
prompt,
|
||||
seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfgScale: cfg_scale,
|
||||
sampler,
|
||||
seamless,
|
||||
shouldRandomizeSeed,
|
||||
} = generation;
|
||||
|
||||
// missing fields in TextToImageInvocation: strength, hires_fix
|
||||
return {
|
||||
[nodeId]: {
|
||||
id: nodeId,
|
||||
type: 'txt2img',
|
||||
prompt,
|
||||
seed: shouldRandomizeSeed ? -1 : seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfg_scale,
|
||||
scheduler: sampler as TextToImageInvocation['scheduler'],
|
||||
seamless,
|
||||
model,
|
||||
progress_images: shouldDisplayInProgressType === 'full-res',
|
||||
},
|
||||
};
|
||||
};
|
||||
@@ -1,10 +1,8 @@
|
||||
import React, { lazy, PropsWithChildren, useEffect, useState } from 'react';
|
||||
import React, { lazy, PropsWithChildren } from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { PersistGate } from 'redux-persist/integration/react';
|
||||
import { buildMiddleware, store } from './app/store';
|
||||
import { store } from './app/store';
|
||||
import { persistor } from './persistor';
|
||||
import { OpenAPI } from 'services/api';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import '@fontsource/inter/100.css';
|
||||
import '@fontsource/inter/200.css';
|
||||
import '@fontsource/inter/300.css';
|
||||
@@ -19,61 +17,18 @@ import Loading from './Loading';
|
||||
|
||||
// Localization
|
||||
import './i18n';
|
||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||
|
||||
const App = lazy(() => import('./app/App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
|
||||
|
||||
interface Props extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
disabledPanels?: string[];
|
||||
disabledTabs?: InvokeTabName[];
|
||||
token?: string;
|
||||
shouldTransformUrls?: boolean;
|
||||
}
|
||||
|
||||
export default function Component({
|
||||
apiUrl,
|
||||
disabledPanels = [],
|
||||
disabledTabs = [],
|
||||
token,
|
||||
children,
|
||||
shouldTransformUrls,
|
||||
}: Props) {
|
||||
useEffect(() => {
|
||||
// configure API client token
|
||||
if (token) {
|
||||
OpenAPI.TOKEN = token;
|
||||
}
|
||||
|
||||
// configure API client base url
|
||||
if (apiUrl) {
|
||||
OpenAPI.BASE = apiUrl;
|
||||
}
|
||||
|
||||
// reset dynamically added middlewares
|
||||
resetMiddlewares();
|
||||
|
||||
// TODO: at this point, after resetting the middleware, we really ought to clean up the socket
|
||||
// stuff by calling `dispatch(socketReset())`. but we cannot dispatch from here as we are
|
||||
// outside the provider. it's not needed until there is the possibility that we will change
|
||||
// the `apiUrl`/`token` dynamically.
|
||||
|
||||
// rebuild socket middleware with token and apiUrl
|
||||
addMiddleware(buildMiddleware());
|
||||
}, [apiUrl, token]);
|
||||
|
||||
export default function Component(props: PropsWithChildren) {
|
||||
return (
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||
<React.Suspense fallback={<Loading showText />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App
|
||||
options={{ disabledPanels, disabledTabs, shouldTransformUrls }}
|
||||
>
|
||||
{children}
|
||||
</App>
|
||||
<App>{props.children}</App>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</PersistGate>
|
||||
|
||||
@@ -5,8 +5,6 @@ import ThemeChanger from './features/system/components/ThemeChanger';
|
||||
import IAIPopover from './common/components/IAIPopover';
|
||||
import IAIIconButton from './common/components/IAIIconButton';
|
||||
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
|
||||
import StatusIndicator from './features/system/components/StatusIndicator';
|
||||
import ModelSelect from 'features/system/components/ModelSelect';
|
||||
|
||||
export default Component;
|
||||
export {
|
||||
@@ -15,6 +13,4 @@ export {
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
StatusIndicator,
|
||||
ModelSelect,
|
||||
};
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
||||
import { ImageConfig } from 'konva/lib/shapes/Image';
|
||||
import { isEqual } from 'lodash';
|
||||
@@ -26,7 +25,7 @@ type Props = Omit<ImageConfig, 'image'>;
|
||||
const IAICanvasIntermediateImage = (props: Props) => {
|
||||
const { ...rest } = props;
|
||||
const intermediateImage = useAppSelector(selector);
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
const [loadedImageElement, setLoadedImageElement] =
|
||||
useState<HTMLImageElement | null>(null);
|
||||
|
||||
@@ -37,8 +36,8 @@ const IAICanvasIntermediateImage = (props: Props) => {
|
||||
tempImage.onload = () => {
|
||||
setLoadedImageElement(tempImage);
|
||||
};
|
||||
tempImage.src = getUrl(intermediateImage.url);
|
||||
}, [intermediateImage, getUrl]);
|
||||
tempImage.src = intermediateImage.url;
|
||||
}, [intermediateImage]);
|
||||
|
||||
if (!intermediateImage?.boundingBox) return null;
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
||||
import { isEqual } from 'lodash';
|
||||
@@ -33,7 +32,6 @@ const selector = createSelector(
|
||||
|
||||
const IAICanvasObjectRenderer = () => {
|
||||
const { objects } = useAppSelector(selector);
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
if (!objects) return null;
|
||||
|
||||
@@ -42,12 +40,7 @@ const IAICanvasObjectRenderer = () => {
|
||||
{objects.map((obj, i) => {
|
||||
if (isCanvasBaseImage(obj)) {
|
||||
return (
|
||||
<IAICanvasImage
|
||||
key={i}
|
||||
x={obj.x}
|
||||
y={obj.y}
|
||||
url={getUrl(obj.image.url)}
|
||||
/>
|
||||
<IAICanvasImage key={i} x={obj.x} y={obj.y} url={obj.image.url} />
|
||||
);
|
||||
} else if (isCanvasBaseLine(obj)) {
|
||||
const line = (
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { GroupConfig } from 'konva/lib/Group';
|
||||
import { isEqual } from 'lodash';
|
||||
@@ -54,16 +53,11 @@ const IAICanvasStagingArea = (props: Props) => {
|
||||
width,
|
||||
height,
|
||||
} = useAppSelector(selector);
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
return (
|
||||
<Group {...rest}>
|
||||
{shouldShowStagingImage && currentStagingAreaImage && (
|
||||
<IAICanvasImage
|
||||
url={getUrl(currentStagingAreaImage.image.url)}
|
||||
x={x}
|
||||
y={y}
|
||||
/>
|
||||
<IAICanvasImage url={currentStagingAreaImage.image.url} x={x} y={y} />
|
||||
)}
|
||||
{shouldShowStagingOutline && (
|
||||
<Group>
|
||||
|
||||
@@ -156,7 +156,7 @@ export const canvasSlice = createSlice({
|
||||
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
|
||||
state.cursorPosition = action.payload;
|
||||
},
|
||||
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
|
||||
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||
const image = action.payload;
|
||||
const { stageDimensions } = state;
|
||||
|
||||
@@ -291,7 +291,7 @@ export const canvasSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
boundingBox: IRect;
|
||||
image: InvokeAI._Image;
|
||||
image: InvokeAI.Image;
|
||||
}>
|
||||
) => {
|
||||
const { boundingBox, image } = action.payload;
|
||||
|
||||
@@ -37,7 +37,7 @@ export type CanvasImage = {
|
||||
y: number;
|
||||
width: number;
|
||||
height: number;
|
||||
image: InvokeAI._Image;
|
||||
image: InvokeAI.Image;
|
||||
};
|
||||
|
||||
export type CanvasMaskLine = {
|
||||
@@ -125,7 +125,7 @@ export interface CanvasState {
|
||||
cursorPosition: Vector2d | null;
|
||||
doesCanvasNeedScaling: boolean;
|
||||
futureLayerStates: CanvasLayerState[];
|
||||
intermediateImage?: InvokeAI._Image;
|
||||
intermediateImage?: InvokeAI.Image;
|
||||
isCanvasInitialized: boolean;
|
||||
isDrawing: boolean;
|
||||
isMaskEnabled: boolean;
|
||||
|
||||
@@ -105,7 +105,7 @@ export const mergeAndUploadCanvas =
|
||||
|
||||
const { url, width, height } = image;
|
||||
|
||||
const newImage: InvokeAI._Image = {
|
||||
const newImage: InvokeAI.Image = {
|
||||
uuid: uuidv4(),
|
||||
category: shouldSaveToGallery ? 'result' : 'user',
|
||||
...image,
|
||||
|
||||
@@ -14,9 +14,8 @@ import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
|
||||
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
|
||||
import {
|
||||
initialImageSelected,
|
||||
setAllParameters,
|
||||
// setInitialImage,
|
||||
setInitialImage,
|
||||
setSeed,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
|
||||
@@ -28,6 +27,7 @@ import {
|
||||
} from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
setActiveTab,
|
||||
setShouldHidePreview,
|
||||
setShouldShowImageDetails,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
@@ -39,6 +39,8 @@ import {
|
||||
FaDownload,
|
||||
FaExpand,
|
||||
FaExpandArrowsAlt,
|
||||
FaEye,
|
||||
FaEyeSlash,
|
||||
FaGrinStars,
|
||||
FaQuoteRight,
|
||||
FaSeedling,
|
||||
@@ -46,15 +48,11 @@ import {
|
||||
FaShareAlt,
|
||||
FaTrash,
|
||||
} from 'react-icons/fa';
|
||||
import {
|
||||
gallerySelector,
|
||||
selectedImageSelector,
|
||||
} from '../store/gallerySelectors';
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import DeleteImageModal from './DeleteImageModal';
|
||||
import { useCallback } from 'react';
|
||||
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
|
||||
const currentImageButtonsSelector = createSelector(
|
||||
[
|
||||
@@ -64,7 +62,6 @@ const currentImageButtonsSelector = createSelector(
|
||||
uiSelector,
|
||||
lightboxSelector,
|
||||
activeTabNameSelector,
|
||||
selectedImageSelector,
|
||||
],
|
||||
(
|
||||
system: SystemState,
|
||||
@@ -72,8 +69,7 @@ const currentImageButtonsSelector = createSelector(
|
||||
postprocessing,
|
||||
ui,
|
||||
lightbox,
|
||||
activeTabName,
|
||||
selectedImage
|
||||
activeTabName
|
||||
) => {
|
||||
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
|
||||
system;
|
||||
@@ -82,7 +78,7 @@ const currentImageButtonsSelector = createSelector(
|
||||
|
||||
const { isLightboxOpen } = lightbox;
|
||||
|
||||
const { shouldShowImageDetails } = ui;
|
||||
const { shouldShowImageDetails, shouldHidePreview } = ui;
|
||||
|
||||
const { intermediateImage, currentImage } = gallery;
|
||||
|
||||
@@ -98,7 +94,7 @@ const currentImageButtonsSelector = createSelector(
|
||||
shouldShowImageDetails,
|
||||
activeTabName,
|
||||
isLightboxOpen,
|
||||
selectedImage,
|
||||
shouldHidePreview,
|
||||
};
|
||||
},
|
||||
{
|
||||
@@ -125,32 +121,27 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
facetoolStrength,
|
||||
shouldDisableToolbarButtons,
|
||||
shouldShowImageDetails,
|
||||
// currentImage,
|
||||
currentImage,
|
||||
isLightboxOpen,
|
||||
activeTabName,
|
||||
selectedImage,
|
||||
shouldHidePreview,
|
||||
} = useAppSelector(currentImageButtonsSelector);
|
||||
const { getUrl, shouldTransformUrls } = useGetUrl();
|
||||
|
||||
const toast = useToast();
|
||||
const { t } = useTranslation();
|
||||
const setBothPrompts = useSetBothPrompts();
|
||||
|
||||
const handleClickUseAsInitialImage = () => {
|
||||
if (!selectedImage) return;
|
||||
if (!currentImage) return;
|
||||
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
|
||||
dispatch(initialImageSelected(selectedImage.name));
|
||||
// dispatch(setInitialImage(currentImage));
|
||||
|
||||
// dispatch(setActiveTab('img2img'));
|
||||
dispatch(setInitialImage(currentImage));
|
||||
dispatch(setActiveTab('img2img'));
|
||||
};
|
||||
|
||||
const handleCopyImage = async () => {
|
||||
if (!selectedImage) return;
|
||||
if (!currentImage) return;
|
||||
|
||||
const blob = await fetch(getUrl(selectedImage.url)).then((res) =>
|
||||
res.blob()
|
||||
);
|
||||
const blob = await fetch(currentImage.url).then((res) => res.blob());
|
||||
const data = [new ClipboardItem({ [blob.type]: blob })];
|
||||
|
||||
await navigator.clipboard.write(data);
|
||||
@@ -164,26 +155,24 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
};
|
||||
|
||||
const handleCopyImageLink = () => {
|
||||
const url = selectedImage
|
||||
? shouldTransformUrls
|
||||
? getUrl(selectedImage.url)
|
||||
: window.location.toString() + selectedImage.url
|
||||
: '';
|
||||
|
||||
navigator.clipboard.writeText(url).then(() => {
|
||||
toast({
|
||||
title: t('toast.imageLinkCopied'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
navigator.clipboard
|
||||
.writeText(
|
||||
currentImage ? window.location.toString() + currentImage.url : ''
|
||||
)
|
||||
.then(() => {
|
||||
toast({
|
||||
title: t('toast.imageLinkCopied'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
useHotkeys(
|
||||
'shift+i',
|
||||
() => {
|
||||
if (selectedImage) {
|
||||
if (currentImage) {
|
||||
handleClickUseAsInitialImage();
|
||||
toast({
|
||||
title: t('toast.sentToImageToImage'),
|
||||
@@ -201,27 +190,28 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
});
|
||||
}
|
||||
},
|
||||
[selectedImage]
|
||||
[currentImage]
|
||||
);
|
||||
|
||||
const handlePreviewVisibility = () => {
|
||||
dispatch(setShouldHidePreview(!shouldHidePreview));
|
||||
};
|
||||
|
||||
const handleClickUseAllParameters = () => {
|
||||
if (!selectedImage) return;
|
||||
// selectedImage.metadata &&
|
||||
// dispatch(setAllParameters(selectedImage.metadata));
|
||||
// if (selectedImage.metadata?.image.type === 'img2img') {
|
||||
// dispatch(setActiveTab('img2img'));
|
||||
// } else if (selectedImage.metadata?.image.type === 'txt2img') {
|
||||
// dispatch(setActiveTab('txt2img'));
|
||||
// }
|
||||
if (!currentImage) return;
|
||||
currentImage.metadata && dispatch(setAllParameters(currentImage.metadata));
|
||||
if (currentImage.metadata?.image.type === 'img2img') {
|
||||
dispatch(setActiveTab('img2img'));
|
||||
} else if (currentImage.metadata?.image.type === 'txt2img') {
|
||||
dispatch(setActiveTab('txt2img'));
|
||||
}
|
||||
};
|
||||
|
||||
useHotkeys(
|
||||
'a',
|
||||
() => {
|
||||
if (
|
||||
['txt2img', 'img2img'].includes(
|
||||
selectedImage?.metadata?.sd_metadata?.type
|
||||
)
|
||||
['txt2img', 'img2img'].includes(currentImage?.metadata?.image?.type)
|
||||
) {
|
||||
handleClickUseAllParameters();
|
||||
toast({
|
||||
@@ -240,18 +230,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
});
|
||||
}
|
||||
},
|
||||
[selectedImage]
|
||||
[currentImage]
|
||||
);
|
||||
|
||||
const handleClickUseSeed = () => {
|
||||
selectedImage?.metadata &&
|
||||
dispatch(setSeed(selectedImage.metadata.sd_metadata.seed));
|
||||
currentImage?.metadata &&
|
||||
dispatch(setSeed(currentImage.metadata.image.seed));
|
||||
};
|
||||
|
||||
useHotkeys(
|
||||
's',
|
||||
() => {
|
||||
if (selectedImage?.metadata?.sd_metadata?.seed) {
|
||||
if (currentImage?.metadata?.image?.seed) {
|
||||
handleClickUseSeed();
|
||||
toast({
|
||||
title: t('toast.seedSet'),
|
||||
@@ -269,19 +259,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
});
|
||||
}
|
||||
},
|
||||
[selectedImage]
|
||||
[currentImage]
|
||||
);
|
||||
|
||||
const handleClickUsePrompt = useCallback(() => {
|
||||
if (selectedImage?.metadata?.sd_metadata?.prompt) {
|
||||
setBothPrompts(selectedImage?.metadata?.sd_metadata?.prompt);
|
||||
if (currentImage?.metadata?.image?.prompt) {
|
||||
setBothPrompts(currentImage?.metadata?.image?.prompt);
|
||||
}
|
||||
}, [selectedImage?.metadata?.sd_metadata?.prompt, setBothPrompts]);
|
||||
}, [currentImage?.metadata?.image?.prompt, setBothPrompts]);
|
||||
|
||||
useHotkeys(
|
||||
'p',
|
||||
() => {
|
||||
if (selectedImage?.metadata?.sd_metadata?.prompt) {
|
||||
if (currentImage?.metadata?.image?.prompt) {
|
||||
handleClickUsePrompt();
|
||||
toast({
|
||||
title: t('toast.promptSet'),
|
||||
@@ -299,11 +289,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
});
|
||||
}
|
||||
},
|
||||
[selectedImage]
|
||||
[currentImage]
|
||||
);
|
||||
|
||||
const handleClickUpscale = () => {
|
||||
// selectedImage && dispatch(runESRGAN(selectedImage));
|
||||
currentImage && dispatch(runESRGAN(currentImage));
|
||||
};
|
||||
|
||||
useHotkeys(
|
||||
@@ -327,7 +317,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
}
|
||||
},
|
||||
[
|
||||
selectedImage,
|
||||
currentImage,
|
||||
isESRGANAvailable,
|
||||
shouldDisableToolbarButtons,
|
||||
isConnected,
|
||||
@@ -337,7 +327,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
);
|
||||
|
||||
const handleClickFixFaces = () => {
|
||||
// selectedImage && dispatch(runFacetool(selectedImage));
|
||||
currentImage && dispatch(runFacetool(currentImage));
|
||||
};
|
||||
|
||||
useHotkeys(
|
||||
@@ -361,7 +351,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
}
|
||||
},
|
||||
[
|
||||
selectedImage,
|
||||
currentImage,
|
||||
isGFPGANAvailable,
|
||||
shouldDisableToolbarButtons,
|
||||
isConnected,
|
||||
@@ -374,10 +364,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
dispatch(setShouldShowImageDetails(!shouldShowImageDetails));
|
||||
|
||||
const handleSendToCanvas = () => {
|
||||
if (!selectedImage) return;
|
||||
if (!currentImage) return;
|
||||
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
|
||||
|
||||
// dispatch(setInitialCanvasImage(selectedImage));
|
||||
dispatch(setInitialCanvasImage(currentImage));
|
||||
dispatch(requestCanvasRescale());
|
||||
|
||||
if (activeTabName !== 'unifiedCanvas') {
|
||||
@@ -395,7 +385,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
useHotkeys(
|
||||
'i',
|
||||
() => {
|
||||
if (selectedImage) {
|
||||
if (currentImage) {
|
||||
handleClickShowImageDetails();
|
||||
} else {
|
||||
toast({
|
||||
@@ -406,7 +396,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
});
|
||||
}
|
||||
},
|
||||
[selectedImage, shouldShowImageDetails]
|
||||
[currentImage, shouldShowImageDetails]
|
||||
);
|
||||
|
||||
const handleLightBox = () => {
|
||||
@@ -467,13 +457,28 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
{t('parameters.copyImageToLink')}
|
||||
</IAIButton>
|
||||
|
||||
<Link download={true} href={getUrl(selectedImage!.url)}>
|
||||
<Link download={true} href={currentImage?.url}>
|
||||
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
|
||||
{t('parameters.downloadImage')}
|
||||
</IAIButton>
|
||||
</Link>
|
||||
</Flex>
|
||||
</IAIPopover>
|
||||
<IAIIconButton
|
||||
icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />}
|
||||
tooltip={
|
||||
!shouldHidePreview
|
||||
? t('parameters.hidePreview')
|
||||
: t('parameters.showPreview')
|
||||
}
|
||||
aria-label={
|
||||
!shouldHidePreview
|
||||
? t('parameters.hidePreview')
|
||||
: t('parameters.showPreview')
|
||||
}
|
||||
isChecked={shouldHidePreview}
|
||||
onClick={handlePreviewVisibility}
|
||||
/>
|
||||
<IAIIconButton
|
||||
icon={<FaExpand />}
|
||||
tooltip={
|
||||
@@ -496,7 +501,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
icon={<FaQuoteRight />}
|
||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||
aria-label={`${t('parameters.usePrompt')} (P)`}
|
||||
isDisabled={!selectedImage?.metadata?.sd_metadata?.prompt}
|
||||
isDisabled={!currentImage?.metadata?.image?.prompt}
|
||||
onClick={handleClickUsePrompt}
|
||||
/>
|
||||
|
||||
@@ -504,7 +509,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
icon={<FaSeedling />}
|
||||
tooltip={`${t('parameters.useSeed')} (S)`}
|
||||
aria-label={`${t('parameters.useSeed')} (S)`}
|
||||
isDisabled={!selectedImage?.metadata?.sd_metadata?.seed}
|
||||
isDisabled={!currentImage?.metadata?.image?.seed}
|
||||
onClick={handleClickUseSeed}
|
||||
/>
|
||||
|
||||
@@ -514,7 +519,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
aria-label={`${t('parameters.useAll')} (A)`}
|
||||
isDisabled={
|
||||
!['txt2img', 'img2img'].includes(
|
||||
selectedImage?.metadata?.sd_metadata?.type
|
||||
currentImage?.metadata?.image?.type
|
||||
)
|
||||
}
|
||||
onClick={handleClickUseAllParameters}
|
||||
@@ -540,7 +545,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
<IAIButton
|
||||
isDisabled={
|
||||
!isGFPGANAvailable ||
|
||||
!selectedImage ||
|
||||
!currentImage ||
|
||||
!(isConnected && !isProcessing) ||
|
||||
!facetoolStrength
|
||||
}
|
||||
@@ -569,7 +574,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
<IAIButton
|
||||
isDisabled={
|
||||
!isESRGANAvailable ||
|
||||
!selectedImage ||
|
||||
!currentImage ||
|
||||
!(isConnected && !isProcessing) ||
|
||||
!upscalingLevel
|
||||
}
|
||||
@@ -591,15 +596,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
/>
|
||||
</ButtonGroup>
|
||||
|
||||
{/* <DeleteImageModal image={selectedImage}>
|
||||
<DeleteImageModal image={currentImage}>
|
||||
<IAIIconButton
|
||||
icon={<FaTrash />}
|
||||
tooltip={`${t('parameters.deleteImage')} (Del)`}
|
||||
aria-label={`${t('parameters.deleteImage')} (Del)`}
|
||||
isDisabled={!selectedImage || !isConnected || isProcessing}
|
||||
isDisabled={!currentImage || !isConnected || isProcessing}
|
||||
colorScheme="error"
|
||||
/>
|
||||
</DeleteImageModal> */}
|
||||
</DeleteImageModal>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -4,20 +4,17 @@ import { useAppSelector } from 'app/storeHooks';
|
||||
import { isEqual } from 'lodash';
|
||||
|
||||
import { MdPhoto } from 'react-icons/md';
|
||||
import {
|
||||
gallerySelector,
|
||||
selectedImageSelector,
|
||||
} from '../store/gallerySelectors';
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import CurrentImageButtons from './CurrentImageButtons';
|
||||
import CurrentImagePreview from './CurrentImagePreview';
|
||||
|
||||
export const currentImageDisplaySelector = createSelector(
|
||||
[gallerySelector, selectedImageSelector],
|
||||
(gallery, selectedImage) => {
|
||||
[gallerySelector],
|
||||
(gallery) => {
|
||||
const { currentImage, intermediateImage } = gallery;
|
||||
|
||||
return {
|
||||
hasAnImageToDisplay: selectedImage || intermediateImage,
|
||||
hasAnImageToDisplay: currentImage || intermediateImage,
|
||||
};
|
||||
},
|
||||
{
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { FaEyeSlash } from 'react-icons/fa';
|
||||
|
||||
const CurrentImageHidden = () => {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
position: 'absolute',
|
||||
color: 'base.400',
|
||||
}}
|
||||
>
|
||||
<FaEyeSlash size={'30vh'} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default CurrentImageHidden;
|
||||
@@ -1,46 +1,28 @@
|
||||
import { Box, Flex, Image } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash';
|
||||
import { ReactEventHandler } from 'react';
|
||||
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
|
||||
|
||||
import { selectedImageSelector } from '../store/gallerySelectors';
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import CurrentImageFallback from './CurrentImageFallback';
|
||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
import CurrentImageHidden from './CurrentImageHidden';
|
||||
|
||||
export const imagesSelector = createSelector(
|
||||
[uiSelector, selectedImageSelector, systemSelector],
|
||||
(ui, selectedImage, system) => {
|
||||
const { shouldShowImageDetails } = ui;
|
||||
const { progressImage } = system;
|
||||
|
||||
// TODO: Clean this up, this is really gross
|
||||
const imageToDisplay = progressImage
|
||||
? {
|
||||
url: progressImage.dataURL,
|
||||
width: progressImage.width,
|
||||
height: progressImage.height,
|
||||
isProgressImage: true,
|
||||
image: progressImage,
|
||||
}
|
||||
: selectedImage
|
||||
? {
|
||||
url: selectedImage.url,
|
||||
width: selectedImage.metadata.width,
|
||||
height: selectedImage.metadata.height,
|
||||
isProgressImage: false,
|
||||
image: selectedImage,
|
||||
}
|
||||
: null;
|
||||
[gallerySelector, uiSelector],
|
||||
(gallery: GalleryState, ui) => {
|
||||
const { currentImage, intermediateImage } = gallery;
|
||||
const { shouldShowImageDetails, shouldHidePreview } = ui;
|
||||
|
||||
return {
|
||||
imageToDisplay: intermediateImage ? intermediateImage : currentImage,
|
||||
isIntermediate: Boolean(intermediateImage),
|
||||
shouldShowImageDetails,
|
||||
imageToDisplay,
|
||||
shouldHidePreview,
|
||||
};
|
||||
},
|
||||
{
|
||||
@@ -51,9 +33,12 @@ export const imagesSelector = createSelector(
|
||||
);
|
||||
|
||||
export default function CurrentImagePreview() {
|
||||
const { shouldShowImageDetails, imageToDisplay } =
|
||||
useAppSelector(imagesSelector);
|
||||
const { getUrl } = useGetUrl();
|
||||
const {
|
||||
shouldShowImageDetails,
|
||||
imageToDisplay,
|
||||
isIntermediate,
|
||||
shouldHidePreview,
|
||||
} = useAppSelector(imagesSelector);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@@ -67,15 +52,13 @@ export default function CurrentImagePreview() {
|
||||
>
|
||||
{imageToDisplay && (
|
||||
<Image
|
||||
src={
|
||||
imageToDisplay.isProgressImage
|
||||
? imageToDisplay.url
|
||||
: getUrl(imageToDisplay.url)
|
||||
}
|
||||
src={shouldHidePreview ? undefined : imageToDisplay.url}
|
||||
width={imageToDisplay.width}
|
||||
height={imageToDisplay.height}
|
||||
fallback={
|
||||
!imageToDisplay.isProgressImage ? (
|
||||
shouldHidePreview ? (
|
||||
<CurrentImageHidden />
|
||||
) : !isIntermediate ? (
|
||||
<CurrentImageFallback />
|
||||
) : undefined
|
||||
}
|
||||
@@ -85,31 +68,27 @@ export default function CurrentImagePreview() {
|
||||
maxHeight: '100%',
|
||||
height: 'auto',
|
||||
position: 'absolute',
|
||||
imageRendering: imageToDisplay.isProgressImage
|
||||
? 'pixelated'
|
||||
: 'initial',
|
||||
imageRendering: isIntermediate ? 'pixelated' : 'initial',
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{!shouldShowImageDetails && <NextPrevImageButtons />}
|
||||
{shouldShowImageDetails &&
|
||||
imageToDisplay &&
|
||||
'metadata' in imageToDisplay.image && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: '0',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
borderRadius: 'base',
|
||||
overflow: 'scroll',
|
||||
maxHeight: APP_METADATA_HEIGHT,
|
||||
}}
|
||||
>
|
||||
<ImageMetadataViewer image={imageToDisplay.image} />
|
||||
</Box>
|
||||
)}
|
||||
{shouldShowImageDetails && imageToDisplay && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: '0',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
borderRadius: 'base',
|
||||
overflow: 'scroll',
|
||||
maxHeight: APP_METADATA_HEIGHT,
|
||||
}}
|
||||
>
|
||||
<ImageMetadataViewer image={imageToDisplay} />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ interface DeleteImageModalProps {
|
||||
/**
|
||||
* The image to delete.
|
||||
*/
|
||||
image?: InvokeAI._Image;
|
||||
image?: InvokeAI.Image;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -9,14 +9,11 @@ import {
|
||||
useToast,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { setCurrentImage } from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
imageSelected,
|
||||
setCurrentImage,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
initialImageSelected,
|
||||
setAllImageToImageParameters,
|
||||
setAllParameters,
|
||||
setInitialImage,
|
||||
setSeed,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { DragEvent, memo, useState } from 'react';
|
||||
@@ -34,7 +31,6 @@ import { useTranslation } from 'react-i18next';
|
||||
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
||||
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
|
||||
interface HoverableImageProps {
|
||||
image: InvokeAI.Image;
|
||||
@@ -44,7 +40,7 @@ interface HoverableImageProps {
|
||||
const memoEqualityCheck = (
|
||||
prev: HoverableImageProps,
|
||||
next: HoverableImageProps
|
||||
) => prev.image.name === next.image.name && prev.isSelected === next.isSelected;
|
||||
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
|
||||
|
||||
/**
|
||||
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||
@@ -59,8 +55,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
shouldUseSingleGalleryColumn,
|
||||
} = useAppSelector(hoverableImageSelector);
|
||||
const { image, isSelected } = props;
|
||||
const { url, thumbnail, name, metadata } = image;
|
||||
const { getUrl } = useGetUrl();
|
||||
const { url, thumbnail, uuid, metadata } = image;
|
||||
|
||||
const [isHovered, setIsHovered] = useState<boolean>(false);
|
||||
|
||||
@@ -74,9 +69,10 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
const handleMouseOut = () => setIsHovered(false);
|
||||
|
||||
const handleUsePrompt = () => {
|
||||
if (image.metadata?.sd_metadata?.prompt) {
|
||||
setBothPrompts(image.metadata?.sd_metadata?.prompt);
|
||||
if (image.metadata?.image?.prompt) {
|
||||
setBothPrompts(image.metadata?.image?.prompt);
|
||||
}
|
||||
|
||||
toast({
|
||||
title: t('toast.promptSet'),
|
||||
status: 'success',
|
||||
@@ -86,8 +82,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleUseSeed = () => {
|
||||
image.metadata.sd_metadata &&
|
||||
dispatch(setSeed(image.metadata.sd_metadata.image.seed));
|
||||
image.metadata && dispatch(setSeed(image.metadata.image.seed));
|
||||
toast({
|
||||
title: t('toast.seedSet'),
|
||||
status: 'success',
|
||||
@@ -97,11 +92,20 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleSendToImageToImage = () => {
|
||||
dispatch(initialImageSelected(image.name));
|
||||
dispatch(setInitialImage(image));
|
||||
if (activeTabName !== 'img2img') {
|
||||
dispatch(setActiveTab('img2img'));
|
||||
}
|
||||
toast({
|
||||
title: t('toast.sentToImageToImage'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
};
|
||||
|
||||
const handleSendToCanvas = () => {
|
||||
// dispatch(setInitialCanvasImage(image));
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
|
||||
dispatch(resizeAndScaleCanvas());
|
||||
|
||||
@@ -118,7 +122,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleUseAllParameters = () => {
|
||||
metadata.sd_metadata && dispatch(setAllParameters(metadata.sd_metadata));
|
||||
metadata && dispatch(setAllParameters(metadata));
|
||||
toast({
|
||||
title: t('toast.parametersSet'),
|
||||
status: 'success',
|
||||
@@ -128,13 +132,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleUseInitialImage = async () => {
|
||||
if (metadata.sd_metadata?.image?.init_image_path) {
|
||||
const response = await fetch(
|
||||
metadata.sd_metadata?.image?.init_image_path
|
||||
);
|
||||
if (metadata?.image?.init_image_path) {
|
||||
const response = await fetch(metadata.image.init_image_path);
|
||||
if (response.ok) {
|
||||
dispatch(setActiveTab('img2img'));
|
||||
dispatch(setAllImageToImageParameters(metadata?.sd_metadata));
|
||||
dispatch(setAllImageToImageParameters(metadata));
|
||||
toast({
|
||||
title: t('toast.initialImageSet'),
|
||||
status: 'success',
|
||||
@@ -153,18 +155,16 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
});
|
||||
};
|
||||
|
||||
const handleSelectImage = () => {
|
||||
dispatch(imageSelected(image.name));
|
||||
};
|
||||
const handleSelectImage = () => dispatch(setCurrentImage(image));
|
||||
|
||||
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
|
||||
// e.dataTransfer.setData('invokeai/imageUuid', uuid);
|
||||
// e.dataTransfer.effectAllowed = 'move';
|
||||
e.dataTransfer.setData('invokeai/imageUuid', uuid);
|
||||
e.dataTransfer.effectAllowed = 'move';
|
||||
};
|
||||
|
||||
const handleLightBox = () => {
|
||||
// dispatch(setCurrentImage(image));
|
||||
// dispatch(setIsLightboxOpen(true));
|
||||
dispatch(setCurrentImage(image));
|
||||
dispatch(setIsLightboxOpen(true));
|
||||
};
|
||||
|
||||
return (
|
||||
@@ -177,30 +177,28 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClickCapture={handleUsePrompt}
|
||||
isDisabled={image?.metadata?.sd_metadata?.prompt === undefined}
|
||||
isDisabled={image?.metadata?.image?.prompt === undefined}
|
||||
>
|
||||
{t('parameters.usePrompt')}
|
||||
</MenuItem>
|
||||
|
||||
<MenuItem
|
||||
onClickCapture={handleUseSeed}
|
||||
isDisabled={image?.metadata?.sd_metadata?.seed === undefined}
|
||||
isDisabled={image?.metadata?.image?.seed === undefined}
|
||||
>
|
||||
{t('parameters.useSeed')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClickCapture={handleUseAllParameters}
|
||||
isDisabled={
|
||||
!['txt2img', 'img2img'].includes(
|
||||
image?.metadata?.sd_metadata?.type
|
||||
)
|
||||
!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)
|
||||
}
|
||||
>
|
||||
{t('parameters.useAll')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClickCapture={handleUseInitialImage}
|
||||
isDisabled={image?.metadata?.sd_metadata?.type !== 'img2img'}
|
||||
isDisabled={image?.metadata?.image?.type !== 'img2img'}
|
||||
>
|
||||
{t('parameters.useInitImg')}
|
||||
</MenuItem>
|
||||
@@ -211,9 +209,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
{t('parameters.sendToUnifiedCanvas')}
|
||||
</MenuItem>
|
||||
<MenuItem data-warning>
|
||||
{/* <DeleteImageModal image={image}>
|
||||
<DeleteImageModal image={image}>
|
||||
<p>{t('parameters.deleteImage')}</p>
|
||||
</DeleteImageModal> */}
|
||||
</DeleteImageModal>
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
)}
|
||||
@@ -221,7 +219,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
{(ref) => (
|
||||
<Box
|
||||
position="relative"
|
||||
key={name}
|
||||
key={uuid}
|
||||
onMouseOver={handleMouseOver}
|
||||
onMouseOut={handleMouseOut}
|
||||
userSelect="none"
|
||||
@@ -246,7 +244,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
|
||||
}
|
||||
rounded="md"
|
||||
src={getUrl(thumbnail || url)}
|
||||
src={thumbnail || url}
|
||||
loading="lazy"
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
@@ -292,7 +290,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
insetInlineEnd: 1,
|
||||
}}
|
||||
>
|
||||
{/* <DeleteImageModal image={image}>
|
||||
<DeleteImageModal image={image}>
|
||||
<IAIIconButton
|
||||
aria-label={t('parameters.deleteImage')}
|
||||
icon={<FaTrashAlt />}
|
||||
@@ -300,7 +298,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
fontSize={14}
|
||||
isDisabled={!mayDeleteImage}
|
||||
/>
|
||||
</DeleteImageModal> */}
|
||||
</DeleteImageModal>
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ButtonGroup, Flex, Grid, Icon, Image, Text } from '@chakra-ui/react';
|
||||
import { ButtonGroup, Flex, Grid, Icon, Text } from '@chakra-ui/react';
|
||||
import { requestImages } from 'app/socketio/actions';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
@@ -25,44 +25,9 @@ import HoverableImage from './HoverableImage';
|
||||
|
||||
import Scrollable from 'features/ui/components/common/Scrollable';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import {
|
||||
resultsAdapter,
|
||||
selectResultsAll,
|
||||
selectResultsTotal,
|
||||
} from '../store/resultsSlice';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
} from 'services/thunks/gallery';
|
||||
import { selectUploadsAll, uploadsAdapter } from '../store/uploadsSlice';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
|
||||
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
||||
|
||||
const gallerySelector = createSelector(
|
||||
[
|
||||
(state: RootState) => state.uploads,
|
||||
(state: RootState) => state.results,
|
||||
(state: RootState) => state.gallery,
|
||||
],
|
||||
(uploads, results, gallery) => {
|
||||
const { currentCategory } = gallery;
|
||||
|
||||
return currentCategory === 'result'
|
||||
? {
|
||||
images: resultsAdapter.getSelectors().selectAll(results),
|
||||
isLoading: results.isLoading,
|
||||
areMoreImagesAvailable: results.page < results.pages - 1,
|
||||
}
|
||||
: {
|
||||
images: uploadsAdapter.getSelectors().selectAll(uploads),
|
||||
isLoading: uploads.isLoading,
|
||||
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
const ImageGalleryContent = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
@@ -70,7 +35,7 @@ const ImageGalleryContent = () => {
|
||||
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
|
||||
|
||||
const {
|
||||
// images,
|
||||
images,
|
||||
currentCategory,
|
||||
currentImageUuid,
|
||||
shouldPinGallery,
|
||||
@@ -78,24 +43,12 @@ const ImageGalleryContent = () => {
|
||||
galleryGridTemplateColumns,
|
||||
galleryImageObjectFit,
|
||||
shouldAutoSwitchToNewImages,
|
||||
// areMoreImagesAvailable,
|
||||
areMoreImagesAvailable,
|
||||
shouldUseSingleGalleryColumn,
|
||||
} = useAppSelector(imageGallerySelector);
|
||||
|
||||
const { images, areMoreImagesAvailable, isLoading } =
|
||||
useAppSelector(gallerySelector);
|
||||
|
||||
// const handleClickLoadMore = () => {
|
||||
// dispatch(requestImages(currentCategory));
|
||||
// };
|
||||
const handleClickLoadMore = () => {
|
||||
if (currentCategory === 'result') {
|
||||
dispatch(receivedResultImagesPage());
|
||||
}
|
||||
|
||||
if (currentCategory === 'user') {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
}
|
||||
dispatch(requestImages(currentCategory));
|
||||
};
|
||||
|
||||
const handleChangeGalleryImageMinimumWidth = (v: number) => {
|
||||
@@ -250,11 +203,11 @@ const ImageGalleryContent = () => {
|
||||
style={{ gridTemplateColumns: galleryGridTemplateColumns }}
|
||||
>
|
||||
{images.map((image) => {
|
||||
const { name } = image;
|
||||
const isSelected = currentImageUuid === name;
|
||||
const { uuid } = image;
|
||||
const isSelected = currentImageUuid === uuid;
|
||||
return (
|
||||
<HoverableImage
|
||||
key={name}
|
||||
key={uuid}
|
||||
image={image}
|
||||
isSelected={isSelected}
|
||||
/>
|
||||
@@ -264,7 +217,6 @@ const ImageGalleryContent = () => {
|
||||
<IAIButton
|
||||
onClick={handleClickLoadMore}
|
||||
isDisabled={!areMoreImagesAvailable}
|
||||
isLoading={isLoading}
|
||||
flexShrink={0}
|
||||
>
|
||||
{areMoreImagesAvailable
|
||||
|
||||
@@ -11,7 +11,6 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import * as InvokeAI from 'app/invokeai';
|
||||
import { useAppDispatch } from 'app/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import promptToString from 'common/util/promptToString';
|
||||
import { seedWeightsToString } from 'common/util/seedWeightPairs';
|
||||
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
||||
@@ -19,7 +18,7 @@ import {
|
||||
setCfgScale,
|
||||
setHeight,
|
||||
setImg2imgStrength,
|
||||
// setInitialImage,
|
||||
setInitialImage,
|
||||
setMaskPath,
|
||||
setPerlin,
|
||||
setSampler,
|
||||
@@ -46,7 +45,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaCopy } from 'react-icons/fa';
|
||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||
import * as png from '@stevebel/png';
|
||||
|
||||
type MetadataItemProps = {
|
||||
isLink?: boolean;
|
||||
@@ -122,7 +120,7 @@ type ImageMetadataViewerProps = {
|
||||
const memoEqualityCheck = (
|
||||
prev: ImageMetadataViewerProps,
|
||||
next: ImageMetadataViewerProps
|
||||
) => prev.image.name === next.image.name;
|
||||
) => prev.image.uuid === next.image.uuid;
|
||||
|
||||
// TODO: Show more interesting information in this component.
|
||||
|
||||
@@ -139,8 +137,8 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
dispatch(setShouldShowImageDetails(false));
|
||||
});
|
||||
|
||||
const metadata = image?.metadata.sd_metadata || {};
|
||||
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
|
||||
const metadata = image?.metadata?.image || {};
|
||||
const dreamPrompt = image?.dreamPrompt;
|
||||
|
||||
const {
|
||||
cfg_scale,
|
||||
@@ -162,23 +160,11 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
type,
|
||||
variations,
|
||||
width,
|
||||
model_weights,
|
||||
} = metadata;
|
||||
|
||||
const { t } = useTranslation();
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
const metadataJSON = JSON.stringify(image, null, 2);
|
||||
|
||||
// fetch(getUrl(image.url))
|
||||
// .then((r) => r.arrayBuffer())
|
||||
// .then((buffer) => {
|
||||
// const { text } = png.decode(buffer);
|
||||
// const metadata = text?.['sd-metadata']
|
||||
// ? JSON.parse(text['sd-metadata'] ?? {})
|
||||
// : {};
|
||||
// console.log(metadata);
|
||||
// });
|
||||
const metadataJSON = JSON.stringify(image.metadata, null, 2);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@@ -197,49 +183,18 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
>
|
||||
<Flex gap={2}>
|
||||
<Text fontWeight="semibold">File:</Text>
|
||||
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
|
||||
<Link href={image.url} isExternal maxW="calc(100% - 3rem)">
|
||||
{image.url.length > 64
|
||||
? image.url.substring(0, 64).concat('...')
|
||||
: image.url}
|
||||
<ExternalLinkIcon mx="2px" />
|
||||
</Link>
|
||||
</Flex>
|
||||
<Flex gap={2} direction="column">
|
||||
<Flex gap={2}>
|
||||
<Tooltip label="Copy metadata JSON">
|
||||
<IconButton
|
||||
aria-label={t('accessibility.copyMetadataJson')}
|
||||
icon={<FaCopy />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={14}
|
||||
onClick={() => navigator.clipboard.writeText(metadataJSON)}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Text fontWeight="semibold">Metadata JSON:</Text>
|
||||
</Flex>
|
||||
<Box
|
||||
sx={{
|
||||
mt: 0,
|
||||
mr: 2,
|
||||
mb: 4,
|
||||
ml: 2,
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
overflowX: 'scroll',
|
||||
wordBreak: 'break-all',
|
||||
bg: 'whiteAlpha.500',
|
||||
_dark: { bg: 'blackAlpha.500' },
|
||||
}}
|
||||
>
|
||||
<pre>{metadataJSON}</pre>
|
||||
</Box>
|
||||
</Flex>
|
||||
{Object.keys(metadata).length > 0 ? (
|
||||
<>
|
||||
{type && <MetadataItem label="Generation type" value={type} />}
|
||||
{model_weights && (
|
||||
<MetadataItem label="Model" value={model_weights} />
|
||||
{image.metadata?.model_weights && (
|
||||
<MetadataItem label="Model" value={image.metadata.model_weights} />
|
||||
)}
|
||||
{['esrgan', 'gfpgan'].includes(type) && (
|
||||
<MetadataItem label="Original image" value={orig_path} />
|
||||
@@ -333,14 +288,14 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
onClick={() => dispatch(setHeight(height))}
|
||||
/>
|
||||
)}
|
||||
{/* {init_image_path && (
|
||||
{init_image_path && (
|
||||
<MetadataItem
|
||||
label="Initial image"
|
||||
value={init_image_path}
|
||||
isLink
|
||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
||||
/>
|
||||
)} */}
|
||||
)}
|
||||
{mask_image_path && (
|
||||
<MetadataItem
|
||||
label="Mask image"
|
||||
@@ -453,6 +408,37 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
{dreamPrompt && (
|
||||
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
|
||||
)}
|
||||
<Flex gap={2} direction="column">
|
||||
<Flex gap={2}>
|
||||
<Tooltip label="Copy metadata JSON">
|
||||
<IconButton
|
||||
aria-label={t('accessibility.copyMetadataJson')}
|
||||
icon={<FaCopy />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={14}
|
||||
onClick={() => navigator.clipboard.writeText(metadataJSON)}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Text fontWeight="semibold">Metadata JSON:</Text>
|
||||
</Flex>
|
||||
<Box
|
||||
sx={{
|
||||
mt: 0,
|
||||
mr: 2,
|
||||
mb: 4,
|
||||
ml: 2,
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
overflowX: 'scroll',
|
||||
wordBreak: 'break-all',
|
||||
bg: 'whiteAlpha.500',
|
||||
_dark: { bg: 'blackAlpha.500' },
|
||||
}}
|
||||
>
|
||||
<pre>{metadataJSON}</pre>
|
||||
</Box>
|
||||
</Flex>
|
||||
</>
|
||||
) : (
|
||||
<Center width="100%" pt={10}>
|
||||
|
||||
@@ -7,16 +7,6 @@ import {
|
||||
uiSelector,
|
||||
} from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash';
|
||||
import {
|
||||
selectResultsAll,
|
||||
selectResultsById,
|
||||
selectResultsEntities,
|
||||
} from './resultsSlice';
|
||||
import {
|
||||
selectUploadsAll,
|
||||
selectUploadsById,
|
||||
selectUploadsEntities,
|
||||
} from './uploadsSlice';
|
||||
|
||||
export const gallerySelector = (state: RootState) => state.gallery;
|
||||
|
||||
@@ -85,18 +75,3 @@ export const hoverableImageSelector = createSelector(
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export const selectedImageSelector = createSelector(
|
||||
[gallerySelector, selectResultsEntities, selectUploadsEntities],
|
||||
(gallery, allResults, allUploads) => {
|
||||
const selectedImageName = gallery.selectedImageName;
|
||||
|
||||
if (selectedImageName in allResults) {
|
||||
return allResults[selectedImageName];
|
||||
}
|
||||
|
||||
if (selectedImageName in allUploads) {
|
||||
return allUploads[selectedImageName];
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import * as InvokeAI from 'app/invokeai';
|
||||
import { invocationComplete } from 'services/events/actions';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { IRect } from 'konva/lib/types';
|
||||
import { clamp } from 'lodash';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
|
||||
export type GalleryCategory = 'user' | 'result';
|
||||
|
||||
export type AddImagesPayload = {
|
||||
images: Array<InvokeAI._Image>;
|
||||
images: Array<InvokeAI.Image>;
|
||||
areMoreImagesAvailable: boolean;
|
||||
category: GalleryCategory;
|
||||
};
|
||||
@@ -19,33 +16,16 @@ export type AddImagesPayload = {
|
||||
type GalleryImageObjectFitType = 'contain' | 'cover';
|
||||
|
||||
export type Gallery = {
|
||||
images: InvokeAI._Image[];
|
||||
images: InvokeAI.Image[];
|
||||
latest_mtime?: number;
|
||||
earliest_mtime?: number;
|
||||
areMoreImagesAvailable: boolean;
|
||||
};
|
||||
|
||||
export interface GalleryState {
|
||||
/**
|
||||
* The selected image's unique name
|
||||
* Use `selectedImageSelector` to access the image
|
||||
*/
|
||||
selectedImageName: string;
|
||||
/**
|
||||
* The currently selected image
|
||||
* @deprecated See `state.gallery.selectedImageName`
|
||||
*/
|
||||
currentImage?: InvokeAI._Image;
|
||||
/**
|
||||
* The currently selected image's uuid.
|
||||
* @deprecated See `state.gallery.selectedImageName`, use `selectedImageSelector` to access the image
|
||||
*/
|
||||
currentImage?: InvokeAI.Image;
|
||||
currentImageUuid: string;
|
||||
/**
|
||||
* The current progress image
|
||||
* @deprecated See `state.system.progressImage`
|
||||
*/
|
||||
intermediateImage?: InvokeAI._Image & {
|
||||
intermediateImage?: InvokeAI.Image & {
|
||||
boundingBox?: IRect;
|
||||
generationMode?: InvokeTabName;
|
||||
};
|
||||
@@ -62,7 +42,6 @@ export interface GalleryState {
|
||||
}
|
||||
|
||||
const initialState: GalleryState = {
|
||||
selectedImageName: '',
|
||||
currentImageUuid: '',
|
||||
galleryImageMinimumWidth: 64,
|
||||
galleryImageObjectFit: 'cover',
|
||||
@@ -90,10 +69,7 @@ export const gallerySlice = createSlice({
|
||||
name: 'gallery',
|
||||
initialState,
|
||||
reducers: {
|
||||
imageSelected: (state, action: PayloadAction<string>) => {
|
||||
state.selectedImageName = action.payload;
|
||||
},
|
||||
setCurrentImage: (state, action: PayloadAction<InvokeAI._Image>) => {
|
||||
setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||
state.currentImage = action.payload;
|
||||
state.currentImageUuid = action.payload.uuid;
|
||||
},
|
||||
@@ -148,7 +124,7 @@ export const gallerySlice = createSlice({
|
||||
addImage: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
image: InvokeAI._Image;
|
||||
image: InvokeAI.Image;
|
||||
category: GalleryCategory;
|
||||
}>
|
||||
) => {
|
||||
@@ -174,10 +150,7 @@ export const gallerySlice = createSlice({
|
||||
setIntermediateImage: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
InvokeAI._Image & {
|
||||
boundingBox?: IRect;
|
||||
generationMode?: InvokeTabName;
|
||||
}
|
||||
InvokeAI.Image & { boundingBox?: IRect; generationMode?: InvokeTabName }
|
||||
>
|
||||
) => {
|
||||
state.intermediateImage = action.payload;
|
||||
@@ -279,31 +252,9 @@ export const gallerySlice = createSlice({
|
||||
state.shouldUseSingleGalleryColumn = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Invocation Complete
|
||||
*/
|
||||
builder.addCase(invocationComplete, (state, action) => {
|
||||
const { data } = action.payload;
|
||||
if (isImageOutput(data.result)) {
|
||||
state.selectedImageName = data.result.image.image_name;
|
||||
state.intermediateImage = undefined;
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Upload Image - FULFILLED
|
||||
*/
|
||||
builder.addCase(imageUploaded.fulfilled, (state, action) => {
|
||||
const { location } = action.payload;
|
||||
const imageName = location.split('/').pop() || '';
|
||||
state.selectedImageName = imageName;
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
imageSelected,
|
||||
addImage,
|
||||
clearIntermediateImage,
|
||||
removeImage,
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { Image } from 'app/invokeai';
|
||||
import { invocationComplete } from 'services/events/actions';
|
||||
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import {
|
||||
buildImageUrls,
|
||||
deserializeImageField,
|
||||
extractTimestampFromImageName,
|
||||
} from 'services/util/deserializeImageField';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
import { getUrlAlt } from 'common/util/getUrl';
|
||||
import { ImageMetadata } from 'services/api';
|
||||
// import { deserializeImageField } from 'services/util/deserializeImageField';
|
||||
|
||||
// use `createEntityAdapter` to create a slice for results images
|
||||
// https://redux-toolkit.js.org/api/createEntityAdapter#overview
|
||||
|
||||
// the "Entity" is InvokeAI.ResultImage, while the "entities" are instances of that type
|
||||
export const resultsAdapter = createEntityAdapter<Image>({
|
||||
// Provide a callback to get a stable, unique identifier for each entity. This defaults to
|
||||
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
|
||||
selectId: (image) => image.name,
|
||||
// Order all images by their time (in descending order)
|
||||
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
|
||||
});
|
||||
|
||||
// This type is intersected with the Entity type to create the shape of the state
|
||||
type AdditionalResultsState = {
|
||||
// these are a bit misleading; they refer to sessions, not results, but we don't have a route
|
||||
// to list all images directly at this time...
|
||||
page: number; // current page we are on
|
||||
pages: number; // the total number of pages available
|
||||
isLoading: boolean; // whether we are loading more images or not, mostly a placeholder
|
||||
nextPage: number; // the next page to request
|
||||
};
|
||||
|
||||
// export type ResultsState = ReturnType<
|
||||
// typeof resultsAdapter.getInitialState<AdditionalResultsState>
|
||||
// >;
|
||||
|
||||
export const initialResultsState =
|
||||
resultsAdapter.getInitialState<AdditionalResultsState>({
|
||||
// provide the additional initial state
|
||||
page: 0,
|
||||
pages: 0,
|
||||
isLoading: false,
|
||||
nextPage: 0,
|
||||
});
|
||||
|
||||
export type ResultsState = typeof initialResultsState;
|
||||
|
||||
const resultsSlice = createSlice({
|
||||
name: 'results',
|
||||
initialState: initialResultsState,
|
||||
reducers: {
|
||||
// the adapter provides some helper reducers; see the docs for all of them
|
||||
// can use them as helper functions within a reducer, or use the function itself as a reducer
|
||||
|
||||
// here we just use the function itself as the reducer. we'll call this on `invocation_complete`
|
||||
// to add a single result
|
||||
resultAdded: resultsAdapter.upsertOne,
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
|
||||
// because we pass in the fulfilled thunk action creator, everything is typed
|
||||
|
||||
/**
|
||||
* Received Result Images Page - PENDING
|
||||
*/
|
||||
builder.addCase(receivedResultImagesPage.pending, (state) => {
|
||||
state.isLoading = true;
|
||||
});
|
||||
|
||||
/**
|
||||
* Received Result Images Page - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||
const { items, page, pages } = action.payload;
|
||||
|
||||
const resultImages = items.map((image) =>
|
||||
deserializeImageResponse(image)
|
||||
);
|
||||
|
||||
// use the adapter reducer to append all the results to state
|
||||
resultsAdapter.addMany(state, resultImages);
|
||||
|
||||
state.page = page;
|
||||
state.pages = pages;
|
||||
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
|
||||
state.isLoading = false;
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation Complete
|
||||
*/
|
||||
builder.addCase(invocationComplete, (state, action) => {
|
||||
const { data } = action.payload;
|
||||
const { result, invocation, graph_execution_state_id, source_id } = data;
|
||||
|
||||
if (isImageOutput(result)) {
|
||||
const name = result.image.image_name;
|
||||
const type = result.image.image_type;
|
||||
const { url, thumbnail } = buildImageUrls(type, name);
|
||||
|
||||
const timestamp = extractTimestampFromImageName(name);
|
||||
|
||||
const image: Image = {
|
||||
name,
|
||||
type,
|
||||
url,
|
||||
thumbnail,
|
||||
metadata: {
|
||||
created: timestamp,
|
||||
width: result.width, // TODO: add tese dimensions
|
||||
height: result.height,
|
||||
invokeai: {
|
||||
session: graph_execution_state_id,
|
||||
source_id,
|
||||
invocation,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// const resultImage = deserializeImageField(result.image, invocation);
|
||||
resultsAdapter.addOne(state, image);
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
// Create a set of memoized selectors based on the location of this entity state
|
||||
// to be used as selectors in a `useAppSelector()` call
|
||||
export const {
|
||||
selectAll: selectResultsAll,
|
||||
selectById: selectResultsById,
|
||||
selectEntities: selectResultsEntities,
|
||||
selectIds: selectResultsIds,
|
||||
selectTotal: selectResultsTotal,
|
||||
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
|
||||
|
||||
export const { resultAdded } = resultsSlice.actions;
|
||||
|
||||
export default resultsSlice.reducer;
|
||||
@@ -0,0 +1,54 @@
|
||||
import { AnyAction, ThunkAction } from '@reduxjs/toolkit';
|
||||
import * as InvokeAI from 'app/invokeai';
|
||||
import { RootState } from 'app/store';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { setInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { addImage } from '../gallerySlice';
|
||||
|
||||
type UploadImageConfig = {
|
||||
imageFile: File;
|
||||
};
|
||||
|
||||
export const uploadImage =
|
||||
(
|
||||
config: UploadImageConfig
|
||||
): ThunkAction<void, RootState, unknown, AnyAction> =>
|
||||
async (dispatch, getState) => {
|
||||
const { imageFile } = config;
|
||||
|
||||
const state = getState() as RootState;
|
||||
|
||||
const activeTabName = activeTabNameSelector(state);
|
||||
|
||||
const formData = new FormData();
|
||||
|
||||
formData.append('file', imageFile, imageFile.name);
|
||||
formData.append(
|
||||
'data',
|
||||
JSON.stringify({
|
||||
kind: 'init',
|
||||
})
|
||||
);
|
||||
|
||||
const response = await fetch(`${window.location.origin}/upload`, {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
});
|
||||
|
||||
const image = (await response.json()) as InvokeAI.ImageUploadResponse;
|
||||
const newImage: InvokeAI.Image = {
|
||||
uuid: uuidv4(),
|
||||
category: 'user',
|
||||
...image,
|
||||
};
|
||||
|
||||
dispatch(addImage({ image: newImage, category: 'user' }));
|
||||
|
||||
if (activeTabName === 'unifiedCanvas') {
|
||||
dispatch(setInitialCanvasImage(newImage));
|
||||
} else if (activeTabName === 'img2img') {
|
||||
dispatch(setInitialImage(newImage));
|
||||
}
|
||||
};
|
||||
@@ -1,95 +0,0 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { Image } from 'app/invokeai';
|
||||
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
receivedUploadImagesPage,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { deserializeImageField } from 'services/util/deserializeImageField';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
|
||||
export const uploadsAdapter = createEntityAdapter<Image>({
|
||||
selectId: (image) => image.name,
|
||||
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
|
||||
});
|
||||
|
||||
type AdditionalUploadsState = {
|
||||
page: number;
|
||||
pages: number;
|
||||
isLoading: boolean;
|
||||
nextPage: number;
|
||||
};
|
||||
|
||||
export type UploadssState = ReturnType<
|
||||
typeof uploadsAdapter.getInitialState<AdditionalUploadsState>
|
||||
>;
|
||||
|
||||
const uploadsSlice = createSlice({
|
||||
name: 'uploads',
|
||||
initialState: uploadsAdapter.getInitialState<AdditionalUploadsState>({
|
||||
page: 0,
|
||||
pages: 0,
|
||||
nextPage: 0,
|
||||
isLoading: false,
|
||||
}),
|
||||
reducers: {
|
||||
uploadAdded: uploadsAdapter.addOne,
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
/**
|
||||
* Received Upload Images Page - PENDING
|
||||
*/
|
||||
builder.addCase(receivedUploadImagesPage.pending, (state) => {
|
||||
state.isLoading = true;
|
||||
});
|
||||
|
||||
/**
|
||||
* Received Upload Images Page - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
||||
const { items, page, pages } = action.payload;
|
||||
|
||||
const images = items.map((image) => deserializeImageResponse(image));
|
||||
|
||||
uploadsAdapter.addMany(state, images);
|
||||
|
||||
state.page = page;
|
||||
state.pages = pages;
|
||||
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
|
||||
state.isLoading = false;
|
||||
});
|
||||
|
||||
/**
|
||||
* Upload Image - FULFILLED
|
||||
*/
|
||||
builder.addCase(imageUploaded.fulfilled, (state, action) => {
|
||||
const { location, response } = action.payload;
|
||||
const { image_name, image_url, image_type, metadata, thumbnail_url } =
|
||||
response;
|
||||
|
||||
const uploadedImage: Image = {
|
||||
name: image_name,
|
||||
url: image_url,
|
||||
thumbnail: thumbnail_url,
|
||||
type: 'uploads',
|
||||
metadata,
|
||||
};
|
||||
|
||||
uploadsAdapter.addOne(state, uploadedImage);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectUploadsAll,
|
||||
selectById: selectUploadsById,
|
||||
selectEntities: selectUploadsEntities,
|
||||
selectIds: selectUploadsIds,
|
||||
selectTotal: selectUploadsTotal,
|
||||
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
|
||||
|
||||
export const { uploadAdded } = uploadsSlice.actions;
|
||||
|
||||
export default uploadsSlice.reducer;
|
||||
@@ -1,10 +1,9 @@
|
||||
import * as React from 'react';
|
||||
import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
|
||||
import * as InvokeAI from 'app/invokeai';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
|
||||
type ReactPanZoomProps = {
|
||||
image: InvokeAI._Image;
|
||||
image: InvokeAI.Image;
|
||||
styleClass?: string;
|
||||
alt?: string;
|
||||
ref?: React.Ref<HTMLImageElement>;
|
||||
@@ -23,7 +22,6 @@ export default function ReactPanZoomImage({
|
||||
scaleY,
|
||||
}: ReactPanZoomProps) {
|
||||
const { centerView } = useTransformContext();
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
return (
|
||||
<TransformComponent
|
||||
@@ -37,7 +35,7 @@ export default function ReactPanZoomImage({
|
||||
transform: `rotate(${rotation}deg) scaleX(${scaleX}) scaleY(${scaleY})`,
|
||||
width: '100%',
|
||||
}}
|
||||
src={getUrl(image.url)}
|
||||
src={image.url}
|
||||
alt={alt}
|
||||
ref={ref}
|
||||
className={styleClass ? styleClass : ''}
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import 'reactflow/dist/style.css';
|
||||
import { useCallback } from 'react';
|
||||
import {
|
||||
Tooltip,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem,
|
||||
IconButton,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { nodeAdded } from '../store/nodesSlice';
|
||||
import { map } from 'lodash';
|
||||
import { RootState } from 'app/store';
|
||||
|
||||
export const AddNodeMenu = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const invocations = useAppSelector(
|
||||
(state: RootState) => state.nodes.invocations
|
||||
);
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: string) => {
|
||||
dispatch(nodeAdded({ id: uuidv4(), invocation: invocations[nodeType] }));
|
||||
},
|
||||
[dispatch, invocations]
|
||||
);
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton as={IconButton} aria-label="Add Node" icon={<FaPlus />} />
|
||||
<MenuList>
|
||||
{map(invocations, ({ title, description, type }, key) => {
|
||||
return (
|
||||
<Tooltip key={key} label={description} placement="end" hasArrow>
|
||||
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
};
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user