Compare commits

..

31 Commits

Author SHA1 Message Date
Lincoln Stein
fdad62e88b chore: add ".version" and ".last_model" to gitignore (#3208)
Mistakenly closed the previous pr.
2023-04-20 18:26:27 +01:00
Lincoln Stein
955c81acef Merge branch 'main' into patch-1 2023-04-20 18:26:06 +01:00
Lincoln Stein
e1058f3416 update CODEOWNERS for changed team composition (#3234)
Remove @mauwii and @keturn until they are able to reengage with the
development effort. @GreggHelt2 is designated co-codeowner for the
backend.
2023-04-20 17:19:10 +01:00
Sammy
edf16a253d Merge branch 'main' into patch-1 2023-04-20 14:16:10 +02:00
Lincoln Stein
46f5ef4100 Merge branch 'main' into dev/codeowner-fix-main 2023-04-19 22:40:56 +01:00
Lincoln Stein
b843255236 update CODEOWNERS for changed team composition 2023-04-19 17:37:48 -04:00
Alexandre D. Roberge
3a968e5072 Update NSFW.md
Outdated doc said to change the '.invokeai' file, but it's now named 'invokeai.init' afaik.
2023-04-18 21:18:32 -04:00
Lincoln Stein
fd80e84ea6 Merge branch 'main' into patch-1 2023-04-18 19:14:28 -04:00
Lincoln Stein
4824237a98 Added CPU instruction for README (#3225)
Since the change itself is quite straight-forward, I'll just describe
the context. Tried using automatic installer on my laptop, kept erroring
out on line 140-something of installer.py, "ERROR: Can not perform a
'--user' install. User site-packages are not visible in this
virtualenv."
Got tired of of fighting with pip so moved on to command line install.
Worked immediately, but at the time lacked instruction for CPU, so
instead of opening any helpful hyperlinks in the readme, took a few
minutes to grab the link from installer.py - thus this pr.
2023-04-18 19:07:37 -04:00
Leo Pasanen
2c9a05eb59 Added CPU instruction for README 2023-04-18 18:46:55 +03:00
blessedcoolant
ecb5bdaf7e [bug] #3218 HuggingFace API off when --no-internet set (#3219)
#3218 

Huggingface API will not be queried if --no-internet flag is set
2023-04-18 14:34:34 +12:00
Tim Cabbage
f6cdff2c5b [bug] #3218 HuggingFace API off when --no-internet set
https://github.com/invoke-ai/InvokeAI/issues/3218

Huggingface API will not be queried if --no-internet flag is set
2023-04-17 16:53:31 +02:00
Eugene
63d10027a4 nodes: invocation queue item - make more pydantic 2023-04-16 09:39:33 -04:00
Eugene
ef0773b8a3 nodes: set default for InvocationQueueItem.invoke_all 2023-04-16 09:39:33 -04:00
Eugene
3daaddf15b nodes: remove duplicate LatentsToLatentsInvocation 2023-04-16 09:39:33 -04:00
Eugene
570c3fe690 nodes: ensure Graph and GraphExecutionState ids are cast to str on instantiation 2023-04-16 09:39:33 -04:00
Eugene
cbd1a7263a nodes: fix typing of GraphExecutionState.id 2023-04-16 09:39:33 -04:00
Eugene
7fc5fbd4ce nodes: convert InvocationQueueItem to Pydantic class 2023-04-16 09:39:33 -04:00
Eugene Brodsky
6f6de402ad make InvocationQueueItem serializable 2023-04-16 09:39:33 -04:00
Sammy
281662a6e1 chore: add ".version" and ".last_model" to gitignore
Mistakenly closed the previous pr
2023-04-15 21:46:47 +02:00
SammCheese
50eb02f68b chore(ui): build 2023-04-15 20:45:17 +10:00
SammCheese
d73f3adc43 moving shouldHidePreview from gallery to ui slice. 2023-04-15 20:45:17 +10:00
SammCheese
116107f464 chore(ui): build 2023-04-15 20:45:17 +10:00
SammCheese
da44bb1707 rename setter 2023-04-15 20:45:17 +10:00
SammCheese
f43aed677e chore(ui): build 2023-04-15 20:45:17 +10:00
SammCheese
0d051aaae2 rename hidden variable to something more descriptive 2023-04-15 20:45:17 +10:00
SammCheese
e4e48ff995 i forgor to push the locale 2023-04-15 20:45:17 +10:00
SammCheese
442a6bffa4 feat: add "Hide Preview" Button 2023-04-15 20:45:17 +10:00
Kyle Schouviller
23d65e7162 [nodes] Add subgraph library, subgraph usage in CLI, and fix subgraph execution (#3180)
* Add latent to latent (img2img equivalent)
Fix a CLI bug with multiple links per node

* Using "latents" instead of "latent"

* [nodes] In-progress implementation of graph library

* Add linking to CLI for graph nodes (still broken)

* Fix subgraph execution, fix subgraph linking in CLI

* Fix LatentsToLatents
2023-04-14 06:41:06 +00:00
blessedcoolant
024fd54d0b Fixed a Typo. (#3190) 2023-04-14 14:33:31 +12:00
Nicholas Körfer
c44c19e911 Fixed a Typo. 2023-04-13 17:42:34 +02:00
287 changed files with 1426 additions and 10649 deletions

14
.github/CODEOWNERS vendored
View File

@@ -1,16 +1,16 @@
# continuous integration
/.github/workflows/ @mauwii @lstein @blessedcoolant
/.github/workflows/ @lstein @blessedcoolant
# documentation
/docs/ @lstein @mauwii @tildebyte @blessedcoolant
/mkdocs.yml @lstein @mauwii @blessedcoolant
/docs/ @lstein @tildebyte @blessedcoolant
/mkdocs.yml @lstein @blessedcoolant
# nodes
/invokeai/app/ @Kyle0654 @blessedcoolant
# installation and configuration
/pyproject.toml @mauwii @lstein @blessedcoolant
/docker/ @mauwii @lstein @blessedcoolant
/pyproject.toml @lstein @blessedcoolant
/docker/ @lstein @blessedcoolant
/scripts/ @ebr @lstein
/installer/ @lstein @ebr
/invokeai/assets @lstein @ebr
@@ -22,11 +22,11 @@
/invokeai/backend @blessedcoolant @psychedelicious @lstein
# generation, model management, postprocessing
/invokeai/backend @keturn @damian0815 @lstein @blessedcoolant @jpphoto
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
# front ends
/invokeai/frontend/CLI @lstein
/invokeai/frontend/install @lstein @ebr @mauwii
/invokeai/frontend/install @lstein @ebr
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/web @psychedelicious @blessedcoolant

2
.gitignore vendored
View File

@@ -9,6 +9,8 @@ models/ldm/stable-diffusion-v1/model.ckpt
configs/models.user.yaml
config/models.user.yml
invokeai.init
.version
.last_model
# ignore the Anaconda/Miniconda installer used while building Docker image
anaconda.sh

View File

@@ -84,7 +84,7 @@ installing lots of models.
6. Wait while the installer does its thing. After installing the software,
the installer will launch a script that lets you configure InvokeAI and
select a set of starting image generaiton models.
select a set of starting image generation models.
7. Find the folder that InvokeAI was installed into (it is not the
same as the unpacked zip file directory!) The default location of this
@@ -148,6 +148,11 @@ not supported.
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
```
_For non-GPU systems:_
```terminal
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
```
_For Macintoshes, either Intel or M1/M2:_
```sh

View File

@@ -32,7 +32,7 @@ turned on and off on the command line using `--nsfw_checker` and
At installation time, InvokeAI will ask whether the checker should be
activated by default (neither argument given on the command line). The
response is stored in the InvokeAI initialization file (usually
`.invokeai` in your home directory). You can change the default at any
`invokeai.init` in your home directory). You can change the default at any
time by opening this file in a text editor and commenting or
uncommenting the line `--nsfw_checker`.

View File

@@ -3,12 +3,14 @@
import os
from argparse import Namespace
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
@@ -69,6 +71,9 @@ class ApiDependencies:
latents=latents,
images=images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
@@ -76,6 +81,8 @@ class ApiDependencies:
restoration=RestorationServices(config),
)
create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services)
@staticmethod

View File

@@ -1,8 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
import uuid
from fastapi import Path, Query, Request, UploadFile
@@ -10,7 +8,6 @@ from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
@@ -43,47 +40,32 @@ async def get_thumbnail(
"/uploads/",
operation_id="upload_image",
responses={
201: {"description": "The image was uploaded successfully", "model": ImageResponse},
201: {"description": "The image was uploaded successfully"},
404: {"description": "Session not found"},
},
status_code=201
)
async def upload_image(file: UploadFile, request: Request, response: Response) -> ImageResponse:
async def upload_image(file: UploadFile, request: Request):
if not file.content_type.startswith("image"):
return Response(status_code=415)
contents = await file.read()
try:
img = Image.open(io.BytesIO(contents))
im = Image.open(contents)
except:
# Error opening the image
return Response(status_code=415)
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
image_path = ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, img)
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
res = ImageResponse(
image_type=ImageType.UPLOAD,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=f"api/v1/images/{ImageType.UPLOAD.value}/{filename}",
thumbnail_url=f"api/v1/images/{ImageType.UPLOAD.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
metadata=ImageMetadata(
created=int(os.path.getctime(image_path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata
),
)
response.status_code = 201
response.headers["Location"] = request.url_for(
return Response(
status_code=201,
headers={
"Location": request.url_for(
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
)
return res
},
)
@images_router.get(
"/",

View File

@@ -7,11 +7,40 @@ from pydantic import BaseModel, Field
import networkx as nx
import matplotlib.pyplot as plt
from ..models.image import ImageField
from ..services.graph import GraphExecutionState
from ..invocations.baseinvocation import BaseInvocation
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
from ..services.invoker import Invoker
def add_field_argument(command_parser, name: str, field, default_override = None):
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=default,
choices=allowed_values,
help=field.field_info.description,
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=default,
help=field.field_info.description,
)
def add_parsers(
subparsers,
commands: list[type],
@@ -36,30 +65,26 @@ def add_parsers(
if name in exclude_fields:
continue
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
add_field_argument(command_parser, name, field)
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=field.default if field.default_factory is None else field.default_factory(),
choices=allowed_values,
help=field.field_info.description,
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=field.default if field.default_factory is None else field.default_factory(),
help=field.field_info.description,
)
def add_graph_parsers(
subparsers,
graphs: list[LibraryGraph],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
):
for graph in graphs:
command_parser = subparsers.add_parser(graph.name, help=graph.description)
if add_arguments is not None:
add_arguments(command_parser)
# Add arguments for inputs
for exposed_input in graph.exposed_inputs:
node = graph.graph.get_node(exposed_input.node_path)
field = node.__fields__[exposed_input.field]
default_override = getattr(node, exposed_input.field)
add_field_argument(command_parser, exposed_input.alias, field, default_override)
class CliContext:
@@ -67,17 +92,38 @@ class CliContext:
session: GraphExecutionState
parser: argparse.ArgumentParser
defaults: dict[str, Any]
graph_nodes: dict[str, str]
nodes_added: list[str]
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
self.invoker = invoker
self.session = session
self.parser = parser
self.defaults = dict()
self.graph_nodes = dict()
self.nodes_added = list()
def get_session(self):
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
return self.session
def reset(self):
self.session = self.invoker.create_execution_state()
self.graph_nodes = dict()
self.nodes_added = list()
# Leave defaults unchanged
def add_node(self, node: BaseInvocation):
self.get_session()
self.session.graph.add_node(node)
self.nodes_added.append(node.id)
self.invoker.services.graph_execution_manager.set(self.session)
def add_edge(self, edge: Edge):
self.get_session()
self.session.add_edge(edge)
self.invoker.services.graph_execution_manager.set(self.session)
class ExitCli(Exception):
"""Exception to exit the CLI"""

View File

@@ -13,17 +13,20 @@ from typing import (
from pydantic import BaseModel
from pydantic.fields import Field
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
from .cli.completer import set_autocompleter
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
@@ -58,7 +61,7 @@ def add_invocation_args(command_parser):
)
def get_command_parser() -> argparse.ArgumentParser:
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
# Create invocation parser
parser = argparse.ArgumentParser()
@@ -76,20 +79,72 @@ def get_command_parser() -> argparse.ArgumentParser:
commands = BaseCommand.get_all_subclasses()
add_parsers(subparsers, commands, exclude_fields=["type"])
# Create subparsers for exposed CLI graphs
# TODO: add a way to identify these graphs
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
return parser
class NodeField():
alias: str
node_path: str
field: str
field_type: type
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
self.alias = alias
self.node_path = node_path
self.field = field
self.field_type = field_type
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
"""Gets the node field for the specified field alias"""
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_input.node_path))
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
"""Gets the node field for the specified field alias"""
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_output.node_path))
node_output_type = node_type.get_output_type()
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
"""Gets the inputs for the specified invocation from the context"""
node_type = type(invocation)
if node_type is not GraphInvocation:
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
else:
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
"""Gets the outputs for the specified invocation from the context"""
node_type = type(invocation)
if node_type is not GraphInvocation:
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
else:
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation
a: BaseInvocation, b: BaseInvocation, context: CliContext
) -> list[Edge]:
"""Generates all possible edges between two invocations"""
atype = type(a)
btype = type(b)
aoutputtype = atype.get_output_type()
afields = get_type_hints(aoutputtype)
bfields = get_type_hints(btype)
afields = get_node_outputs(a, context)
bfields = get_node_inputs(b, context)
matching_fields = set(afields.keys()).intersection(bfields.keys())
@@ -98,14 +153,14 @@ def generate_matching_edges(
matching_fields = matching_fields.difference(invalid_fields)
# Validate types
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
edges = [
Edge(
source=EdgeConnection(node_id=a.id, field=field),
destination=EdgeConnection(node_id=b.id, field=field)
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
)
for field in matching_fields
for alias in matching_fields
]
return edges
@@ -158,6 +213,9 @@ def invoke_cli():
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images'),
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
@@ -165,9 +223,12 @@ def invoke_cli():
restoration=RestorationServices(config),
)
system_graphs = create_system_graphs(services.graph_library)
system_graph_names = set([g.name for g in system_graphs])
invoker = Invoker(services)
session: GraphExecutionState = invoker.create_execution_state()
parser = get_command_parser()
parser = get_command_parser(services)
re_negid = re.compile('^-[0-9]+$')
@@ -185,11 +246,12 @@ def invoke_cli():
try:
# Refresh the state of the session
history = list(get_graph_execution_history(context.session))
#history = list(get_graph_execution_history(context.session))
history = list(reversed(context.nodes_added))
# Split the command for piping
cmds = cmd_input.split("|")
start_id = len(history)
start_id = len(context.nodes_added)
current_id = start_id
new_invocations = list()
for cmd in cmds:
@@ -205,8 +267,24 @@ def invoke_cli():
args[field_name] = field_default
# Parse invocation
args["id"] = current_id
command = CliCommand(command=args)
command: CliCommand = None # type:ignore
system_graph: LibraryGraph|None = None
if args['type'] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
for exposed_input in system_graph.exposed_inputs:
if exposed_input.alias in args:
node = invocation.graph.get_node(exposed_input.node_path)
field = exposed_input.field
setattr(node, field, args[exposed_input.alias])
command = CliCommand(command = invocation)
context.graph_nodes[invocation.id] = system_graph.id
else:
args["id"] = current_id
command = CliCommand(command=args)
if command is None:
continue
# Run any CLI commands immediately
if isinstance(command.command, BaseCommand):
@@ -217,6 +295,7 @@ def invoke_cli():
command.command.run(context)
continue
# TODO: handle linking with library graphs
# Pipe previous command output (if there was a previous command)
edges: list[Edge] = list()
if len(history) > 0 or current_id != start_id:
@@ -229,7 +308,7 @@ def invoke_cli():
else context.session.graph.get_node(from_id)
)
matching_edges = generate_matching_edges(
from_node, command.command
from_node, command.command, context
)
edges.extend(matching_edges)
@@ -242,7 +321,7 @@ def invoke_cli():
link_node = context.session.graph.get_node(node_id)
matching_edges = generate_matching_edges(
link_node, command.command
link_node, command.command, context
)
matching_destinations = [e.destination for e in matching_edges]
edges = [e for e in edges if e.destination not in matching_destinations]
@@ -256,12 +335,14 @@ def invoke_cli():
if re_negid.match(node_id):
node_id = str(current_id + int(node_id))
# TODO: handle missing input/output
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
node_input = get_node_inputs(command.command, context)[link[2]]
edges.append(
Edge(
source=EdgeConnection(node_id=node_id, field=link[1]),
destination=EdgeConnection(
node_id=command.command.id, field=link[2]
)
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
)
)
@@ -270,10 +351,10 @@ def invoke_cli():
current_id = current_id + 1
# Add the node to the session
context.session.add_node(command.command)
context.add_node(command.command)
for edge in edges:
print(edge)
context.session.add_edge(edge)
context.add_edge(edge)
# Execute all remaining nodes
invoke_all(context)
@@ -285,7 +366,7 @@ def invoke_cli():
except SessionError:
# Start a new session
print("Session error: creating a new session")
context.session = context.invoker.create_execution_state()
context.reset()
except ExitCli:
break

View File

@@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
class CvInvocationConfig(BaseModel):
@@ -56,9 +56,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_inpainted, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_inpainted,
context.services.images.save(image_type, image_name, image_inpainted)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -9,9 +9,9 @@ from torch import Tensor
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.invocations.util.get_model import choose_model
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..models.exceptions import CanceledException
@@ -76,7 +76,6 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
self.model_name = model["model_name"]
outputs = Txt2Img(model).generate(
prompt=self.prompt,
@@ -96,22 +95,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_id = graph_execution_state.prepared_source_mapping[self.id]
invocation = graph_execution_state.execution_graph.get_node(self.id)
metadata = {
"session": context.graph_execution_state_id,
"source_id": source_id,
"invocation": invocation.dict()
}
context.services.images.save(image_type, image_name, generate_output.image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=generate_output.image
context.services.images.save(image_type, image_name, generate_output.image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
@@ -158,7 +144,6 @@ class ImageToImageInvocation(TextToImageInvocation):
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
self.model = model["model_name"]
outputs = Img2Img(model).generate(
prompt=self.prompt,
@@ -183,11 +168,9 @@ class ImageToImageInvocation(TextToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class InpaintInvocation(ImageToImageInvocation):
@@ -235,8 +218,7 @@ class InpaintInvocation(ImageToImageInvocation):
)
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
self.model = model["model_name"]
model = choose_model(context.services.model_manager, self.model)
outputs = Inpaint(model).generate(
prompt=self.prompt,
@@ -261,9 +243,7 @@ class InpaintInvocation(ImageToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -9,12 +9,7 @@ from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
class PILInvocationConfig(BaseModel):
@@ -27,70 +22,51 @@ class PILInvocationConfig(BaseModel):
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
#fmt: off
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
#fmt: on
class Config:
schema_extra = {
"required": [
"type",
"image",
"width",
"height",
'required': [
'type',
'image',
]
}
def build_image_output(
image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
image_field = ImageField(image_name=image_name, image_type=image_type)
return ImageOutput(image=image_field, width=image.width, height=image.height)
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
#fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
# fmt: on
#fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
'required': [
'type',
'mask',
]
}
# TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation):
"""Load an image from a filename and provide it as output."""
#fmt: off
type: Literal["load_image"] = "load_image"
# # TODO: this isn't really necessary anymore
# class LoadImageInvocation(BaseInvocation):
# """Load an image from a filename and provide it as output."""
# #fmt: off
# type: Literal["load_image"] = "load_image"
# Inputs
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
#fmt: on
# # Inputs
# image_type: ImageType = Field(description="The type of the image")
# image_name: str = Field(description="The name of the image")
# #fmt: on
# def invoke(self, context: InvocationContext) -> ImageOutput:
# return ImageOutput(
# image_type=self.image_type,
# image_name=self.image_name,
# image=result_image
# )
def invoke(self, context: InvocationContext) -> ImageOutput:
return ImageOutput(
image=ImageField(image_type=self.image_type, image_name=self.image_name)
)
class ShowImageInvocation(BaseInvocation):
@@ -110,17 +86,16 @@ class ShowImageInvocation(BaseInvocation):
# TODO: how to handle failure?
return build_image_output(
image_type=self.image.image_type,
image_name=self.image.image_name,
image=image,
return ImageOutput(
image=ImageField(
image_type=self.image.image_type, image_name=self.image.image_name
)
)
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
"""Crops an image to a specified box. The box can be outside of the image."""
# fmt: off
#fmt: off
type: Literal["crop"] = "crop"
# Inputs
@@ -129,7 +104,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@@ -145,16 +120,15 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_crop, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=image_crop
context.services.images.save(image_type, image_name, image_crop)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
"""Pastes an image into another image."""
# fmt: off
#fmt: off
type: Literal["paste"] = "paste"
# Inputs
@@ -163,7 +137,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(
@@ -196,22 +170,21 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, new_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=new_image
context.services.images.save(image_type, image_name, new_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
"""Extracts the alpha channel of an image as a mask."""
# fmt: off
#fmt: off
type: Literal["tomask"] = "tomask"
# Inputs
image: ImageField = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(
@@ -226,22 +199,22 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_mask, self.dict())
context.services.images.save(image_type, image_name, image_mask)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
class BlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs an image"""
# fmt: off
#fmt: off
type: Literal["blur"] = "blur"
# Inputs
image: ImageField = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@@ -258,23 +231,22 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, blur_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
context.services.images.save(image_type, image_name, blur_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class LerpInvocation(BaseInvocation, PILInvocationConfig):
"""Linear interpolation of all pixels of an image"""
# fmt: off
#fmt: off
type: Literal["lerp"] = "lerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@@ -290,24 +262,23 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, lerp_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
context.services.images.save(image_type, image_name, lerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Inverse linear interpolation of all pixels of an image"""
# fmt: off
#fmt: off
type: Literal["ilerp"] = "ilerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@@ -327,7 +298,7 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, ilerp_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image
context.services.images.save(image_type, image_name, ilerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -1,11 +1,12 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import random
from typing import Literal, Optional
from pydantic import BaseModel, Field
import torch
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.invocations.util.get_model import choose_model
from invokeai.app.util.step_callback import diffusers_step_callback_adapter
from ...backend.model_management.model_manager import ModelManager
@@ -18,7 +19,7 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont
import numpy as np
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output
from .image import ImageField, ImageOutput
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
@@ -99,13 +100,17 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
return x
def random_seed():
return random.randint(0, np.iinfo(np.uint32).max)
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
@@ -166,7 +171,7 @@ class TextToLatentsInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
@@ -180,7 +185,7 @@ class TextToLatentsInvocation(BaseInvocation):
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
model_info = choose_model(model_manager, self.model)
model_name = model_info['model_name']
@@ -190,7 +195,7 @@ class TextToLatentsInvocation(BaseInvocation):
model=model,
scheduler_name=self.scheduler
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
@@ -287,7 +292,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=model.device, dtype=latent.dtype
)
timesteps, _ = model.get_img2img_timesteps(
self.steps,
self.strength,
@@ -350,9 +355,7 @@ class LatentsToImageInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image
context.services.images.save(image_type, image_name, image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -0,0 +1,18 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from .math import IntOutput
# Pass-through parameter nodes - used by subgraphs
class ParamIntInvocation(BaseInvocation):
"""An integer parameter"""
#fmt: off
type: Literal["param_int"] = "param_int"
a: int = Field(default=0, description="The integer value")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a)

View File

@@ -6,7 +6,7 @@ from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
@@ -44,9 +44,7 @@ class RestoreFaceInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0], self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -8,7 +8,7 @@ from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
class UpscaleInvocation(BaseInvocation):
@@ -49,9 +49,7 @@ class UpscaleInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0], self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -1,14 +1,11 @@
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
if model_manager.valid_model(model_name):
model = model_manager.get_model(model_name)
return model_manager.get_model(model_name)
else:
model = model_manager.get_model()
print(
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
)
return model
print(f"* Warning: '{model_name}' is not a valid model name. Using default model instead.")
return model_manager.get_model()

View File

@@ -1,26 +1,11 @@
from typing import Any, Optional, Dict
from typing import Optional
from pydantic import BaseModel, Field
class InvokeAIMetadata(BaseModel):
"""An image's InvokeAI-specific metadata"""
session: Optional[str] = Field(description="The session that generated this image")
source_id: Optional[str] = Field(
description="The source id of the invocation that generated this image"
)
# TODO: figure out metadata
invocation: Optional[Dict[str, Any]] = Field(
default={}, description="The prepared invocation that generated this image"
)
class ImageMetadata(BaseModel):
"""An image's general metadata"""
"""An image's metadata"""
created: int = Field(description="The creation timestamp of the image")
timestamp: float = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
invokeai: Optional[InvokeAIMetadata] = Field(
default={}, description="The image's InvokeAI-specific metadata"
)
# TODO: figure out metadata
sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")

View File

@@ -0,0 +1,56 @@
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
from ..invocations.params import ParamIntInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
from .item_storage import ItemStorageABC
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
def create_text_to_image() -> LibraryGraph:
return LibraryGraph(
id=default_text_to_image_graph_id,
name='t2i',
description='Converts text to an image',
graph=Graph(
nodes={
'width': ParamIntInvocation(id='width', a=512),
'height': ParamIntInvocation(id='height', a=512),
'3': NoiseInvocation(id='3'),
'4': TextToLatentsInvocation(id='4'),
'5': LatentsToImageInvocation(id='5')
},
edges=[
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
]
),
exposed_inputs=[
ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
ExposedNodeInput(node_path='width', field='a', alias='width'),
ExposedNodeInput(node_path='height', field='a', alias='height')
],
exposed_outputs=[
ExposedNodeOutput(node_path='5', field='image', alias='image')
])
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
graphs: list[LibraryGraph] = list()
text_to_image = graph_library.get(default_text_to_image_graph_id)
# TODO: Check if the graph is the same as the default one, and if not, update it
#if text_to_image is None:
text_to_image = create_text_to_image()
graph_library.set(text_to_image)
graphs.append(text_to_image)
return graphs

View File

@@ -25,8 +25,7 @@ class EventServiceBase:
def emit_generator_progress(
self,
graph_execution_state_id: str,
invocation_dict: dict,
source_id: str,
invocation_id: str,
progress_image: ProgressImage | None,
step: int,
total_steps: int,
@@ -36,8 +35,7 @@ class EventServiceBase:
event_name="generator_progress",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation=invocation_dict,
source_id=source_id,
invocation_id=invocation_id,
progress_image=progress_image,
step=step,
total_steps=total_steps,
@@ -45,43 +43,40 @@ class EventServiceBase:
)
def emit_invocation_complete(
self, graph_execution_state_id: str, result: Dict, invocation_dict: Dict, source_id: str,
self, graph_execution_state_id: str, invocation_id: str, result: Dict
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_complete",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation=invocation_dict,
source_id=source_id,
invocation_id=invocation_id,
result=result,
),
)
def emit_invocation_error(
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str, error: str
self, graph_execution_state_id: str, invocation_id: str, error: str
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation=invocation_dict,
source_id=source_id,
invocation_id=invocation_id,
error=error,
),
)
def emit_invocation_started(
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str
self, graph_execution_state_id: str, invocation_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
event_name="invocation_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation=invocation_dict,
source_id=source_id,
invocation_id=invocation_id,
),
)

View File

@@ -2,7 +2,6 @@
import copy
import itertools
import traceback
import uuid
from types import NoneType
from typing import (
@@ -17,7 +16,7 @@ from typing import (
)
import networkx as nx
from pydantic import BaseModel, validator
from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field
from ..invocations import *
@@ -26,7 +25,6 @@ from ..invocations.baseinvocation import (
BaseInvocationOutput,
InvocationContext,
)
from .invocation_services import InvocationServices
class EdgeConnection(BaseModel):
@@ -215,7 +213,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict
@@ -283,7 +281,8 @@ class Graph(BaseModel):
:raises InvalidEdgeError: the provided edge is invalid.
"""
if self._is_edge_valid(edge) and edge not in self.edges:
self._validate_edge(edge)
if edge not in self.edges:
self.edges.append(edge)
else:
raise InvalidEdgeError()
@@ -354,7 +353,7 @@ class Graph(BaseModel):
return True
def _is_edge_valid(self, edge: Edge) -> bool:
def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
@@ -362,54 +361,53 @@ class Graph(BaseModel):
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
return False
raise InvalidEdgeError("One or both nodes don't exist")
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
return False
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
return False
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
# Validate that the field types are compatible
if not are_connections_compatible(
from_node, edge.source.field, to_node, edge.destination.field
):
return False
raise InvalidEdgeError(f'Fields are incompatible')
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid(
edge.destination.node_id, new_input=edge.source
):
return False
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid(
edge.source.node_id, new_output=edge.destination
):
return False
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid(
edge.destination.node_id, new_input=edge.source
):
return False
raise InvalidEdgeError(f'Collector output type does not match collector input type')
# Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid(
edge.source.node_id, new_output=edge.destination
):
return False
raise InvalidEdgeError(f'Collector input type does not match collector output type')
return True
def has_node(self, node_path: str) -> bool:
"""Determines whether or not a node exists in the graph."""
@@ -733,7 +731,7 @@ class Graph(BaseModel):
for sgn in (
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
):
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
# TODO: figure out if iteration nodes need to be expanded
@@ -750,9 +748,7 @@ class Graph(BaseModel):
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(
description="The id of the execution state", default_factory=uuid.uuid4
)
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
# TODO: Store a reference to the graph instead of the actual graph?
graph: Graph = Field(description="The graph being executed")
@@ -858,7 +854,8 @@ class GraphExecutionState(BaseModel):
def is_complete(self) -> bool:
"""Returns true if the graph is complete"""
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
node_ids = set(self.graph.nx_graph_flat().nodes)
return self.has_error() or all((k in self.executed for k in node_ids))
def has_error(self) -> bool:
"""Returns true if the graph has any errors"""
@@ -946,11 +943,11 @@ class GraphExecutionState(BaseModel):
def _iterator_graph(self) -> nx.DiGraph:
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
g = self.graph.nx_graph()
g = self.graph.nx_graph_flat()
collectors = (
n
for n in self.graph.nodes
if isinstance(self.graph.nodes[n], CollectInvocation)
if isinstance(self.graph.get_node(n), CollectInvocation)
)
for c in collectors:
g.remove_edges_from(list(g.in_edges(c)))
@@ -962,7 +959,7 @@ class GraphExecutionState(BaseModel):
iterators = [
n
for n in nx.ancestors(g, node_id)
if isinstance(self.graph.nodes[n], IterateInvocation)
if isinstance(self.graph.get_node(n), IterateInvocation)
]
return iterators
@@ -1098,7 +1095,9 @@ class GraphExecutionState(BaseModel):
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
def _is_edge_valid(self, edge: Edge) -> bool:
if not self._is_edge_valid(edge):
try:
self.graph._validate_edge(edge)
except InvalidEdgeError:
return False
# Invalid if destination has already been prepared or executed
@@ -1144,4 +1143,52 @@ class GraphExecutionState(BaseModel):
self.graph.delete_edge(edge)
class ExposedNodeInput(BaseModel):
node_path: str = Field(description="The node path to the node with the input")
field: str = Field(description="The field name of the input")
alias: str = Field(description="The alias of the input")
class ExposedNodeOutput(BaseModel):
node_path: str = Field(description="The node path to the node with the output")
field: str = Field(description="The field name of the output")
alias: str = Field(description="The alias of the output")
class LibraryGraph(BaseModel):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph")
description: str = Field(description="The description of the graph")
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
@validator('exposed_inputs', 'exposed_outputs')
def validate_exposed_aliases(cls, v):
if len(v) != len(set(i.alias for i in v)):
raise ValueError("Duplicate exposed alias")
return v
@root_validator
def validate_exposed_nodes(cls, values):
graph = values['graph']
# Validate exposed inputs
for exposed_input in values['exposed_inputs']:
if not graph.has_node(exposed_input.node_path):
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
node = graph.get_node(exposed_input.node_path)
if get_input_field(node, exposed_input.field) is None:
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
# Validate exposed outputs
for exposed_output in values['exposed_outputs']:
if not graph.has_node(exposed_output.node_path):
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
node = graph.get_node(exposed_output.node_path)
if get_output_field(node, exposed_output.field) is None:
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
return values
GraphInvocation.update_forward_refs()

View File

@@ -2,17 +2,16 @@
import datetime
import os
import json
from glob import glob
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from queue import Queue
from typing import Any, Callable, Dict, List, Union
from typing import Callable, Dict, List
from PIL.Image import Image
import PIL.Image as PILImage
from pydantic import BaseModel, Json
from pydantic import BaseModel
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.metadata import ImageMetadata
@@ -43,7 +42,7 @@ class ImageStorageBase(ABC):
pass
@abstractmethod
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
pass
@abstractmethod
@@ -101,8 +100,6 @@ class DiskImageStorage(ImageStorageBase):
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
page_of_images.append(
ImageResponse(
image_type=image_type.value,
@@ -112,10 +109,9 @@ class DiskImageStorage(ImageStorageBase):
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
metadata=ImageMetadata(
created=int(os.path.getctime(path)),
timestamp=os.path.getctime(path),
width=img.width,
height=img.height,
invokeai=invokeai_metadata
),
)
)
@@ -154,11 +150,10 @@ class DiskImageStorage(ImageStorageBase):
path = os.path.join(self.__output_folder, image_type, image_name)
return path
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
print(metadata)
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
image_subpath = os.path.join(image_type, image_name)
self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, metadata
image, "", image_subpath, None
) # TODO: just pass full path to png writer
save_thumbnail(
image=image,
@@ -167,7 +162,6 @@ class DiskImageStorage(ImageStorageBase):
)
image_path = self.get_path(image_type, image_name)
self.__set_cache(image_path, image)
return image_path
def delete(self, image_type: ImageType, image_name: str) -> None:
image_path = self.get_path(image_type, image_name)

View File

@@ -1,30 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import time
from abc import ABC, abstractmethod
from queue import Queue
import time
from pydantic import BaseModel, Field
# TODO: make this serializable
class InvocationQueueItem:
# session_id: str
graph_execution_state_id: str
invocation_id: str
invoke_all: bool
timestamp: float
def __init__(
self,
# session_id: str,
graph_execution_state_id: str,
invocation_id: str,
invoke_all: bool = False,
):
# self.session_id = session_id
self.graph_execution_state_id = graph_execution_state_id
self.invocation_id = invocation_id
self.invoke_all = invoke_all
self.timestamp = time.time()
class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
invocation_id: str = Field(description="The ID of the node being invoked")
invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time)
class InvocationQueueABC(ABC):

View File

@@ -19,6 +19,7 @@ class InvocationServices:
restoration: RestorationServices
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
processor: "InvocationProcessorABC"
@@ -29,6 +30,7 @@ class InvocationServices:
latents: LatentsStorageBase,
images: ImageStorageBase,
queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
@@ -38,6 +40,7 @@ class InvocationServices:
self.latents = latents
self.images = images
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration

View File

@@ -43,14 +43,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
queue_item.invocation_id
)
# get the source node to provide to cliepnts (the prepared node is not as useful)
source_id = graph_execution_state.prepared_source_mapping[invocation.id]
# Send starting event
self.__invoker.services.events.emit_invocation_started(
graph_execution_state_id=graph_execution_state.id,
invocation_dict=invocation.dict(),
source_id=source_id
invocation_id=invocation.id,
)
# Invoke
@@ -79,8 +75,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
graph_execution_state_id=graph_execution_state.id,
invocation_dict=invocation.dict(),
source_id=source_id,
invocation_id=invocation.id,
result=outputs.dict(),
)
@@ -104,8 +99,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
invocation_dict=invocation.dict(),
source_id=source_id,
invocation_id=invocation.id,
error=error,
)

View File

@@ -1,25 +1,23 @@
import sqlite3
from threading import Lock
from typing import Generic, TypeVar, Union, get_args
from pydantic import BaseModel, parse_raw_as
from .item_storage import ItemStorageABC, PaginatedResults
from sqlalchemy import create_engine, String, TEXT, Engine, select
from sqlalchemy.orm import DeclarativeBase, mapped_column, Session
T = TypeVar("T", bound=BaseModel)
class Base(DeclarativeBase):
pass
sqlite_memory = ":memory:"
class SqliteItemStorage(ItemStorageABC, Generic[T]):
_filename: str
_table_name: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_id_field: str
_engine: Engine
# _table: ??? # TODO: figure out how to type this
_lock: Lock
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
super().__init__()
@@ -27,79 +25,72 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._filename = filename
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor()
self._engine = create_engine(f"sqlite+pysqlite:///{self._filename}")
self._create_table()
def _create_table(self):
# dynamically create the ORM model class to avoid name collisions
# cannot access `self.__orig_class__` in `__init__` or `__new__` so
# format the table name into the class name
pascal_table_name = self._table_name.replace("_", " ").title()
pascal_table_name = pascal_table_name.replace(" ", "")
table_dict = dict(
__tablename__=self._table_name,
id=mapped_column(String, primary_key=True),
item=mapped_column(TEXT, nullable=False),
)
self._table = type(pascal_table_name, (Base,), table_dict)
Base.metadata.create_all(self._engine)
with self._lock:
self._cursor.execute(
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
item TEXT,
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
)
self._cursor.execute(
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
)
self._conn.commit()
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item)
def set(self, item: T):
session = Session(self._engine)
item_id = str(getattr(item, self._id_field))
new_item = self._table(id=item_id, item=item.json())
session.merge(new_item)
session.commit()
session.close()
with self._lock:
self._cursor.execute(
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),),
)
self._conn.commit()
self._on_changed(item)
def get(self, id: str) -> Union[T, None]:
session = Session(self._engine)
with self._lock:
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone()
item = session.get(self._table, str(id))
session.close()
if not item:
if not result:
return None
return self._parse_item(item.item)
return self._parse_item(result[0])
def delete(self, id: str):
session = Session(self._engine)
item = session.get(self._table, id)
session.delete(item)
session.commit()
session.close()
with self._lock:
self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._conn.commit()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
session = Session(self._engine)
with self._lock:
self._cursor.execute(
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
(per_page, page * per_page),
)
result = self._cursor.fetchall()
stmt = select(self._table.item).limit(per_page).offset(page * per_page)
result = session.execute(stmt)
items = list(map(lambda r: self._parse_item(r[0]), result))
items = list(map(lambda r: self._parse_item(r[0]), result))
count = session.query(self._table.item).count()
session.commit()
session.close()
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
pageCount = int(count / per_page) + 1
@@ -110,19 +101,20 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]:
session = Session(self._engine)
with self._lock:
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
(f"%{query}%", per_page, page * per_page),
)
result = self._cursor.fetchall()
stmt = (
session.query(self._table)
.where(self._table.item.like(f"%{query}%"))
.limit(per_page)
.offset(page * per_page)
)
items = list(map(lambda r: self._parse_item(r[0]), result))
result = session.execute(stmt)
items = list(map(lambda r: self._parse_item(r[0].item), result))
count = session.query(self._table.item).count()
self._cursor.execute(
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
(f"%{query}%",),
)
count = self._cursor.fetchone()[0]
pageCount = int(count / per_page) + 1

View File

@@ -1,4 +1,3 @@
from re import S
import torch
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
@@ -21,18 +20,12 @@ def fast_latents_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_id = graph_execution_state.prepared_source_mapping[id]
invocation = graph_execution_state.execution_graph.get_node(id)
context.services.events.emit_generator_progress(
graph_execution_state_id=context.graph_execution_state_id,
invocation_dict=invocation.dict(),
source_id=source_id,
progress_image={"width": width, "height": height, "dataURL": dataURL},
step=step,
total_steps=steps,
context.graph_execution_state_id,
id,
{"width": width, "height": height, "dataURL": dataURL},
step,
steps,
)

View File

@@ -41,7 +41,7 @@ class PngWriter:
info = PngImagePlugin.PngInfo()
info.add_text("Dream", dream_prompt)
if metadata:
info.add_text("invokeai", json.dumps(metadata))
info.add_text("sd-metadata", json.dumps(metadata))
image.save(path, "PNG", pnginfo=info, compress_level=compress_level)
return path

View File

@@ -57,7 +57,7 @@ class HuggingFaceConceptsLibrary(object):
self.concept_list.extend(list(local_concepts_to_add))
return self.concept_list
return self.concept_list
else:
elif Globals.internet_available is True:
try:
models = self.hf_api.list_models(
filter=ModelFilter(model_name="sd-concepts-library/")
@@ -73,6 +73,8 @@ class HuggingFaceConceptsLibrary(object):
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
)
return self.concept_list
else:
return self.concept_list
def get_concept_model_path(self, concept_name: str) -> str:
"""

View File

@@ -6,5 +6,3 @@ stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*

View File

@@ -3,8 +3,4 @@ dist/
node_modules/
patches/
stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*

View File

@@ -1,16 +1,10 @@
# InvokeAI Web UI
- [InvokeAI Web UI](#invokeai-web-ui)
- [Stack](#stack)
- [Contributing](#contributing)
- [Dev Environment](#dev-environment)
- [Production builds](#production-builds)
The UI is a fairly straightforward Typescript React app. The only really fancy stuff is the Unified Canvas.
Code in `invokeai/frontend/web/` if you want to have a look.
## Stack
## Details
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
@@ -38,7 +32,7 @@ Start everything in dev mode:
1. Start the dev server: `yarn dev`
2. Start the InvokeAI UI per usual: `invokeai --web`
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
3. Point your browser to the dev server address e.g. `http://localhost:5173/`
### Production builds

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,4 @@
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
import{j as y,cO as Ie,r as _,cP as bt,q as Lr,cQ as o,cR as b,cS as v,cT as S,cU as Vr,cV as ut,cW as vt,cN as ft,cX as mt,n as gt,cY as ht,E as pt}from"./index-e53e8108.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-5cde7d31.js";var Or=`
:root {
--chakra-vh: 100vh;
}

View File

@@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
<script type="module" crossorigin src="./assets/index-e53e8108.js"></script>
<link rel="stylesheet" href="./assets/index-5483945c.css">
</head>

View File

@@ -8,7 +8,6 @@
"darkTheme": "داكن",
"lightTheme": "فاتح",
"greenTheme": "أخضر",
"text2img": "نص إلى صورة",
"img2img": "صورة إلى صورة",
"unifiedCanvas": "لوحة موحدة",
"nodes": "عقد",

View File

@@ -7,7 +7,6 @@
"darkTheme": "Dunkel",
"lightTheme": "Hell",
"greenTheme": "Grün",
"text2img": "Text zu Bild",
"img2img": "Bild zu Bild",
"nodes": "Knoten",
"langGerman": "Deutsch",

View File

@@ -505,7 +505,9 @@
"info": "Info",
"deleteImage": "Delete Image",
"initialImage": "Initial Image",
"showOptionsPanel": "Show Options Panel"
"showOptionsPanel": "Show Options Panel",
"hidePreview": "Hide Preview",
"showPreview": "Show Preview"
},
"settings": {
"models": "Models",

View File

@@ -8,7 +8,6 @@
"darkTheme": "Oscuro",
"lightTheme": "Claro",
"greenTheme": "Verde",
"text2img": "Texto a Imagen",
"img2img": "Imagen a Imagen",
"unifiedCanvas": "Lienzo Unificado",
"nodes": "Nodos",
@@ -70,7 +69,11 @@
"langHebrew": "Hebreo",
"pinOptionsPanel": "Pin del panel de opciones",
"loading": "Cargando",
"loadingInvokeAI": "Cargando invocar a la IA"
"loadingInvokeAI": "Cargando invocar a la IA",
"postprocessing": "Tratamiento posterior",
"txt2img": "De texto a imagen",
"accept": "Aceptar",
"cancel": "Cancelar"
},
"gallery": {
"generations": "Generaciones",
@@ -404,7 +407,8 @@
"none": "ninguno",
"pickModelType": "Elige el tipo de modelo",
"v2_768": "v2 (768px)",
"addDifference": "Añadir una diferencia"
"addDifference": "Añadir una diferencia",
"scanForModels": "Buscar modelos"
},
"parameters": {
"images": "Imágenes",
@@ -574,7 +578,7 @@
"autoSaveToGallery": "Guardar automáticamente en galería",
"saveBoxRegionOnly": "Guardar solo región dentro de la caja",
"limitStrokesToBox": "Limitar trazos a la caja",
"showCanvasDebugInfo": "Mostrar información de depuración de lienzo",
"showCanvasDebugInfo": "Mostrar la información adicional del lienzo",
"clearCanvasHistory": "Limpiar historial de lienzo",
"clearHistory": "Limpiar historial",
"clearCanvasHistoryMessage": "Limpiar el historial de lienzo también restablece completamente el lienzo unificado. Esto incluye todo el historial de deshacer/rehacer, las imágenes en el área de preparación y la capa base del lienzo.",

View File

@@ -8,7 +8,6 @@
"darkTheme": "Sombre",
"lightTheme": "Clair",
"greenTheme": "Vert",
"text2img": "Texte en image",
"img2img": "Image en image",
"unifiedCanvas": "Canvas unifié",
"nodes": "Nœuds",
@@ -47,7 +46,19 @@
"statusLoadingModel": "Chargement du modèle",
"statusModelChanged": "Modèle changé",
"discordLabel": "Discord",
"githubLabel": "Github"
"githubLabel": "Github",
"accept": "Accepter",
"statusMergingModels": "Mélange des modèles",
"loadingInvokeAI": "Chargement de Invoke AI",
"cancel": "Annuler",
"langEnglish": "Anglais",
"statusConvertingModel": "Conversion du modèle",
"statusModelConverted": "Modèle converti",
"loading": "Chargement",
"pinOptionsPanel": "Épingler la page d'options",
"statusMergedModels": "Modèles mélangés",
"txt2img": "Texte vers image",
"postprocessing": "Post-Traitement"
},
"gallery": {
"generations": "Générations",
@@ -518,5 +529,15 @@
"betaDarkenOutside": "Assombrir à l'extérieur",
"betaLimitToBox": "Limiter à la boîte",
"betaPreserveMasked": "Conserver masqué"
},
"accessibility": {
"uploadImage": "Charger une image",
"reset": "Réinitialiser",
"nextImage": "Image suivante",
"previousImage": "Image précédente",
"useThisParameter": "Utiliser ce paramètre",
"zoomIn": "Zoom avant",
"zoomOut": "Zoom arrière",
"showOptionsPanel": "Montrer la page d'options"
}
}

View File

@@ -125,7 +125,6 @@
"langSimplifiedChinese": "סינית",
"langUkranian": "אוקראינית",
"langSpanish": "ספרדית",
"text2img": "טקסט לתמונה",
"img2img": "תמונה לתמונה",
"unifiedCanvas": "קנבס מאוחד",
"nodes": "צמתים",

View File

@@ -8,7 +8,6 @@
"darkTheme": "Scuro",
"lightTheme": "Chiaro",
"greenTheme": "Verde",
"text2img": "Testo a Immagine",
"img2img": "Immagine a Immagine",
"unifiedCanvas": "Tela unificata",
"nodes": "Nodi",
@@ -70,7 +69,11 @@
"loading": "Caricamento in corso",
"oceanTheme": "Oceano",
"langHebrew": "Ebraico",
"loadingInvokeAI": "Caricamento Invoke AI"
"loadingInvokeAI": "Caricamento Invoke AI",
"postprocessing": "Post Elaborazione",
"txt2img": "Testo a Immagine",
"accept": "Accetta",
"cancel": "Annulla"
},
"gallery": {
"generations": "Generazioni",
@@ -404,7 +407,8 @@
"v2_768": "v2 (768px)",
"none": "niente",
"addDifference": "Aggiungi differenza",
"pickModelType": "Scegli il tipo di modello"
"pickModelType": "Scegli il tipo di modello",
"scanForModels": "Cerca modelli"
},
"parameters": {
"images": "Immagini",
@@ -574,7 +578,7 @@
"autoSaveToGallery": "Salvataggio automatico nella Galleria",
"saveBoxRegionOnly": "Salva solo l'area di selezione",
"limitStrokesToBox": "Limita i tratti all'area di selezione",
"showCanvasDebugInfo": "Mostra informazioni di debug della Tela",
"showCanvasDebugInfo": "Mostra ulteriori informazioni sulla Tela",
"clearCanvasHistory": "Cancella cronologia Tela",
"clearHistory": "Cancella la cronologia",
"clearCanvasHistoryMessage": "La cancellazione della cronologia della tela lascia intatta la tela corrente, ma cancella in modo irreversibile la cronologia degli annullamenti e dei ripristini.",
@@ -612,7 +616,7 @@
"copyMetadataJson": "Copia i metadati JSON",
"exitViewer": "Esci dal visualizzatore",
"zoomIn": "Zoom avanti",
"zoomOut": "Zoom Indietro",
"zoomOut": "Zoom indietro",
"rotateCounterClockwise": "Ruotare in senso antiorario",
"rotateClockwise": "Ruotare in senso orario",
"flipHorizontally": "Capovolgi orizzontalmente",

View File

@@ -11,7 +11,6 @@
"langArabic": "العربية",
"langEnglish": "English",
"langDutch": "Nederlands",
"text2img": "텍스트->이미지",
"unifiedCanvas": "통합 캔버스",
"langFrench": "Français",
"langGerman": "Deutsch",

View File

@@ -8,7 +8,6 @@
"darkTheme": "Donker",
"lightTheme": "Licht",
"greenTheme": "Groen",
"text2img": "Tekst naar afbeelding",
"img2img": "Afbeelding naar afbeelding",
"unifiedCanvas": "Centraal canvas",
"nodes": "Knooppunten",

View File

@@ -8,7 +8,6 @@
"darkTheme": "Ciemny",
"lightTheme": "Jasny",
"greenTheme": "Zielony",
"text2img": "Tekst na obraz",
"img2img": "Obraz na obraz",
"unifiedCanvas": "Tryb uniwersalny",
"nodes": "Węzły",

View File

@@ -20,7 +20,6 @@
"langSpanish": "Espanhol",
"langRussian": "Русский",
"langUkranian": "Украї́нська",
"text2img": "Texto para Imagem",
"img2img": "Imagem para Imagem",
"unifiedCanvas": "Tela Unificada",
"nodes": "Nós",

View File

@@ -8,7 +8,6 @@
"darkTheme": "Noite",
"lightTheme": "Dia",
"greenTheme": "Verde",
"text2img": "Texto Para Imagem",
"img2img": "Imagem Para Imagem",
"unifiedCanvas": "Tela Unificada",
"nodes": "Nódulos",

View File

@@ -8,7 +8,6 @@
"darkTheme": "Темная",
"lightTheme": "Светлая",
"greenTheme": "Зеленая",
"text2img": "Изображение из текста (text2img)",
"img2img": "Изображение в изображение (img2img)",
"unifiedCanvas": "Универсальный холст",
"nodes": "Ноды",

View File

@@ -8,7 +8,6 @@
"darkTheme": "Темна",
"lightTheme": "Світла",
"greenTheme": "Зелена",
"text2img": "Зображення із тексту (text2img)",
"img2img": "Зображення із зображення (img2img)",
"unifiedCanvas": "Універсальне полотно",
"nodes": "Вузли",

View File

@@ -8,7 +8,6 @@
"darkTheme": "暗色",
"lightTheme": "亮色",
"greenTheme": "绿色",
"text2img": "文字到图像",
"img2img": "图像到图像",
"unifiedCanvas": "统一画布",
"nodes": "节点",

View File

@@ -33,7 +33,6 @@
"langBrPortuguese": "巴西葡萄牙語",
"langRussian": "俄語",
"langSpanish": "西班牙語",
"text2img": "文字到圖像",
"unifiedCanvas": "統一畫布"
}
}

View File

@@ -1,87 +0,0 @@
# Generated axios API client
- [Generated axios API client](#generated-axios-api-client)
- [Generation](#generation)
- [Generate the API client from the nodes web server](#generate-the-api-client-from-the-nodes-web-server)
- [Generate the API client from JSON](#generate-the-api-client-from-json)
- [Getting the JSON from the nodes web server](#getting-the-json-from-the-nodes-web-server)
- [Getting the JSON with a python script](#getting-the-json-with-a-python-script)
- [Generate the API client](#generate-the-api-client)
- [The generated client](#the-generated-client)
- [API client customisation](#api-client-customisation)
This API client is generated by an [openapi code generator](https://github.com/ferdikoomen/openapi-typescript-codegen).
All files in `invokeai/frontend/web/src/services/api/` are made by the generator.
## Generation
The axios client may be generated by from the OpenAPI schema from the nodes web server, or from JSON.
### Generate the API client from the nodes web server
We need to start the nodes web server, which serves the OpenAPI schema to the generator.
1. Start the nodes web server.
```bash
# from the repo root
python scripts/invoke-new.py --web
```
2. Generate the API client.
```bash
# from invokeai/frontend/web/
yarn api:web
```
### Generate the API client from JSON
The JSON can be acquired from the nodes web server, or with a python script.
#### Getting the JSON from the nodes web server
Start the nodes web server as described above, then download the file.
```bash
# from invokeai/frontend/web/
curl http://localhost:9090/openapi.json -o openapi.json
```
#### Getting the JSON with a python script
Run this python script from the repo root, so it can access the nodes server modules.
The script will output `openapi.json` in the repo root. Then we need to move it to `invokeai/frontend/web/`.
```bash
# from the repo root
python invokeai/app/util/generate_openapi_json.py
mv invokeai/app/util/openapi.json invokeai/frontend/web/services/fixtures/
```
#### Generate the API client
Now we can generate the API client from the JSON.
```bash
# from invokeai/frontend/web/
yarn api:file
```
## The generated client
The client will be written to `invokeai/frontend/web/services/api/`:
- `axios` client
- TS types
- An easily parseable schema, which we can use to generate UI
## API client customisation
The generator has a default `request.ts` file that implements a base `axios` client. The generated client uses this base client.
One shortcoming of this is base client is it does not provide response headers unless the response body is empty. To fix this, we provide our own lightly-patched `request.ts`.
To access the headers, call `getHeaders(response)` on any response from the generated api client. This function is exported from `invokeai/frontend/web/src/services/util/getHeaders.ts`.

View File

@@ -1,21 +0,0 @@
# Events
Events via `socket.io`
## `actions.ts`
Redux actions for all socket events. Payloads all include a timestamp, and optionally some other data.
Any reducer (or middleware) can respond to the actions.
## `middleware.ts`
Redux middleware for events.
Handles dispatching the event actions. Only put logic here if it can't really go anywhere else.
For example, on connect we want to load images to the gallery if it's not populated. This requires dispatching a thunk, so we need to directly dispatch this in the middleware.
## `types.ts`
Hand-written types for the socket events. Cannot generate these from the server, but fortunately they are few and simple.

View File

@@ -1,17 +0,0 @@
# Node Editor Design
WIP
nodes
everything in `src/features/nodes/`
have a look at `state.nodes.invocation`
- on socket connect, if no schema saved, fetch `localhost:9090/openapi.json`, save JSON to `state.nodes.schema`
- on fulfilled schema fetch, `parseSchema()` the schema. this outputs a `Record<string, Invocation>` which is saved to `state.nodes.invocations` - `Invocation` is like a template for the node
- when you add a node, the the `Invocation` template is passed to `InvocationComponent.tsx` to build the UI component for that node
- inputs/outputs have field types - and each field type gets an `FieldComponent` which includes a dispatcher to write state changes to redux `nodesSlice`
- `reactflow` sends changes to nodes/edges to redux
- to invoke, `buildNodesGraph()` state, then send this
- changed onClick Invoke button actions to build the schema, then when schema builds it dispatches the actual network request to create the session - see `session.ts`

View File

@@ -1,29 +0,0 @@
# Package Scripts
WIP walkthrough of `package.json` scripts.
## `theme` & `theme:watch`
These run the Chakra CLI to generate types for the theme, or watch for code change and re-generate the types.
The CLI essentially monkeypatches Chakra's files in `node_modules`.
## `postinstall`
The `postinstall` script patches a few packages and runs the Chakra CLI to generate types for the theme.
### Patch `@chakra-ui/cli`
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
### Patch `redux-persist`
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
### Patch `redux-deep-persist`
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.

View File

@@ -1,7 +1,6 @@
import React, { PropsWithChildren } from 'react';
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
import { InvokeTabName } from 'features/ui/store/tabMap';
export {};
@@ -65,24 +64,9 @@ declare module '@invoke-ai/invoke-ai-ui' {
declare class SettingsModal extends React.Component<SettingsModalProps> {
public constructor(props: SettingsModalProps);
}
declare class StatusIndicator extends React.Component<StatusIndicatorProps> {
public constructor(props: StatusIndicatorProps);
}
declare class ModelSelect extends React.Component<ModelSelectProps> {
public constructor(props: ModelSelectProps);
}
}
interface InvokeProps extends PropsWithChildren {
apiUrl?: string;
disabledPanels?: string[];
disabledTabs?: InvokeTabName[];
token?: string;
}
declare function Invoke(props: InvokeProps): JSX.Element;
declare function Invoke(props: PropsWithChildren): JSX.Element;
export {
ThemeChanger,
@@ -90,7 +74,5 @@ export {
IAIPopover,
IAIIconButton,
SettingsModal,
StatusIndicator,
ModelSelect,
};
export = Invoke;

View File

@@ -5,10 +5,7 @@
"scripts": {
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",
@@ -44,10 +41,9 @@
"@chakra-ui/react": "^2.5.1",
"@chakra-ui/styled-system": "^2.6.1",
"@chakra-ui/theme-tools": "^2.0.16",
"@dagrejs/graphlib": "^2.1.12",
"@emotion/react": "^11.10.6",
"@emotion/styled": "^11.10.6",
"@reduxjs/toolkit": "^1.9.3",
"@reduxjs/toolkit": "^1.9.2",
"chakra-ui-contextmenu": "^1.0.5",
"dateformat": "^5.0.3",
"formik": "^2.2.9",
@@ -71,9 +67,7 @@
"react-redux": "^8.0.5",
"react-transition-group": "^4.4.5",
"react-zoom-pan-pinch": "^2.6.1",
"reactflow": "^11.7.0",
"redux-deep-persist": "^1.0.7",
"redux-dynamic-middlewares": "^2.2.0",
"redux-persist": "^6.0.0",
"socket.io-client": "^4.6.0",
"use-image": "^1.1.0",
@@ -89,7 +83,6 @@
"@typescript-eslint/eslint-plugin": "^5.52.0",
"@typescript-eslint/parser": "^5.52.0",
"@vitejs/plugin-react-swc": "^3.2.0",
"axios": "^1.3.4",
"babel-plugin-transform-imports": "^2.0.0",
"concurrently": "^7.6.0",
"eslint": "^8.34.0",
@@ -97,17 +90,13 @@
"eslint-plugin-prettier": "^4.2.1",
"eslint-plugin-react": "^7.32.2",
"eslint-plugin-react-hooks": "^4.6.0",
"form-data": "^4.0.0",
"husky": "^8.0.3",
"lint-staged": "^13.1.2",
"madge": "^6.0.0",
"openapi-types": "^12.1.0",
"openapi-typescript-codegen": "^0.23.0",
"postinstall-postinstall": "^2.1.0",
"prettier": "^2.8.4",
"rollup-plugin-visualizer": "^5.9.0",
"terser": "^5.16.4",
"typescript": "4.9.5",
"vite": "^4.1.2",
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.0.5",

View File

@@ -505,7 +505,9 @@
"info": "Info",
"deleteImage": "Delete Image",
"initialImage": "Initial Image",
"showOptionsPanel": "Show Options Panel"
"showOptionsPanel": "Show Options Panel",
"hidePreview": "Hide Preview",
"showPreview": "Show Preview"
},
"settings": {
"models": "Models",
@@ -522,10 +524,6 @@
"resetComplete": "Web UI has been reset. Refresh the page to reload."
},
"toast": {
"serverError": "Server Error",
"disconnected": "Disconnected from Server",
"connected": "Connected to Server",
"canceled": "Processing Canceled",
"tempFoldersEmptied": "Temp Folder Emptied",
"uploadFailed": "Upload failed",
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",

View File

@@ -13,42 +13,16 @@ import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppDispatch, useAppSelector } from './storeHooks';
import { useAppSelector } from './storeHooks';
import { PropsWithChildren, useEffect } from 'react';
import { setDisabledPanels, setDisabledTabs } from 'features/ui/store/uiSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { shouldTransformUrlsChanged } from 'features/system/store/systemSlice';
keepGUIAlive();
interface Props extends PropsWithChildren {
options: {
disabledPanels: string[];
disabledTabs: InvokeTabName[];
shouldTransformUrls?: boolean;
};
}
const App = (props: Props) => {
const App = (props: PropsWithChildren) => {
useToastWatcher();
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
const { setColorMode } = useColorMode();
const dispatch = useAppDispatch();
useEffect(() => {
dispatch(setDisabledPanels(props.options.disabledPanels));
}, [dispatch, props.options.disabledPanels]);
useEffect(() => {
dispatch(setDisabledTabs(props.options.disabledTabs));
}, [dispatch, props.options.disabledTabs]);
useEffect(() => {
dispatch(
shouldTransformUrlsChanged(Boolean(props.options.shouldTransformUrls))
);
}, [dispatch, props.options.shouldTransformUrls]);
useEffect(() => {
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');

View File

@@ -14,8 +14,6 @@
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageMetadata, ImageType } from 'services/api';
import { AnyInvocation } from 'services/events/types';
/**
* TODO:
@@ -115,7 +113,7 @@ export declare type Metadata = SystemGenerationMetadata & {
};
// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
export declare type _Image = {
export declare type Image = {
uuid: string;
url: string;
thumbnail: string;
@@ -126,23 +124,11 @@ export declare type _Image = {
category: GalleryCategory;
isBase64?: boolean;
dreamPrompt?: 'string';
name?: string;
};
/**
* ResultImage
*/
export declare type Image = {
name: string;
type: ImageType;
url: string;
thumbnail: string;
metadata: ImageMetadata;
};
// GalleryImages is an array of Image.
export declare type GalleryImages = {
images: Array<_Image>;
images: Array<Image>;
};
/**
@@ -289,7 +275,7 @@ export declare type SystemStatusResponse = SystemStatus;
export declare type SystemConfigResponse = SystemConfig;
export declare type ImageResultResponse = Omit<_Image, 'uuid'> & {
export declare type ImageResultResponse = Omit<Image, 'uuid'> & {
boundingBox?: IRect;
generationMode: InvokeTabName;
};
@@ -310,7 +296,7 @@ export declare type ErrorResponse = {
};
export declare type GalleryImagesResponse = {
images: Array<Omit<_Image, 'uuid'>>;
images: Array<Omit<Image, 'uuid'>>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};

View File

@@ -13,13 +13,9 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
export const generateImage = createAction<InvokeTabName>(
'socketio/generateImage'
);
export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI._Image>(
'socketio/runFacetool'
);
export const deleteImage = createAction<InvokeAI._Image>(
'socketio/deleteImage'
);
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI.Image>('socketio/runFacetool');
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
export const requestImages = createAction<GalleryCategory>(
'socketio/requestImages'
);

View File

@@ -91,7 +91,7 @@ const makeSocketIOEmitters = (
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
dispatch(setIsProcessing(true));
const {
@@ -119,7 +119,7 @@ const makeSocketIOEmitters = (
})
);
},
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
emitRunFacetool: (imageToProcess: InvokeAI.Image) => {
dispatch(setIsProcessing(true));
const {
@@ -150,7 +150,7 @@ const makeSocketIOEmitters = (
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
const { url, uuid, category, thumbnail } = imageToDelete;
dispatch(removeImage(imageToDelete));
socketio.emit('deleteImage', url, thumbnail, uuid, category);

View File

@@ -34,9 +34,8 @@ import type { RootState } from 'app/store';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import {
clearInitialImage,
initialImageSelected,
setInfillMethod,
// setInitialImage,
setInitialImage,
setMaskPath,
} from 'features/parameters/store/generationSlice';
import { tabMap } from 'features/ui/store/tabMap';
@@ -147,8 +146,7 @@ const makeSocketIOListeners = (
const activeTabName = tabMap[activeTab];
switch (activeTabName) {
case 'img2img': {
dispatch(initialImageSelected(newImage.uuid));
// dispatch(setInitialImage(newImage));
dispatch(setInitialImage(newImage));
break;
}
}
@@ -264,7 +262,7 @@ const makeSocketIOListeners = (
*/
// Generate a UUID for each image
const preparedImages = images.map((image): InvokeAI._Image => {
const preparedImages = images.map((image): InvokeAI.Image => {
return {
uuid: uuidv4(),
...image,
@@ -336,7 +334,7 @@ const makeSocketIOListeners = (
if (
initialImage === url ||
(initialImage as InvokeAI._Image)?.url === url
(initialImage as InvokeAI.Image)?.url === url
) {
dispatch(clearInitialImage());
}

View File

@@ -29,8 +29,6 @@ export const socketioMiddleware = () => {
path: `${window.location.pathname}socket.io`,
});
socketio.disconnect();
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {

View File

@@ -2,35 +2,18 @@ import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { getPersistConfig } from 'redux-deep-persist';
import canvasReducer from 'features/canvas/store/canvasSlice';
import galleryReducer, {
GalleryState,
} from 'features/gallery/store/gallerySlice';
import resultsReducer, {
resultsAdapter,
ResultsState,
} from 'features/gallery/store/resultsSlice';
import uploadsReducer from 'features/gallery/store/uploadsSlice';
import lightboxReducer, {
LightboxState,
} from 'features/lightbox/store/lightboxSlice';
import generationReducer, {
GenerationState,
} from 'features/parameters/store/generationSlice';
import postprocessingReducer, {
PostprocessingState,
} from 'features/parameters/store/postprocessingSlice';
import systemReducer, { SystemState } from 'features/system/store/systemSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
import uiReducer from 'features/ui/store/uiSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer, { NodesState } from 'features/nodes/store/nodesSlice';
import { socketioMiddleware } from './socketio/middleware';
import { socketMiddleware } from 'services/events/middleware';
import { CanvasState } from 'features/canvas/store/canvasTypes';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
@@ -46,21 +29,13 @@ import { CanvasState } from 'features/canvas/store/canvasTypes';
* The necesssary nested persistors with blacklists are configured below.
*/
/**
* Canvas slice persist blacklist
*/
const canvasBlacklist: (keyof CanvasState)[] = [
const canvasBlacklist = [
'cursorPosition',
'isCanvasInitialized',
'doesCanvasNeedScaling',
];
].map((blacklistItem) => `canvas.${blacklistItem}`);
canvasBlacklist.map((blacklistItem) => `canvas.${blacklistItem}`);
/**
* System slice persist blacklist
*/
const systemBlacklist: (keyof SystemState)[] = [
const systemBlacklist = [
'currentIteration',
'currentStatus',
'currentStep',
@@ -73,101 +48,30 @@ const systemBlacklist: (keyof SystemState)[] = [
'totalIterations',
'totalSteps',
'openModel',
'isCancelScheduled',
'sessionId',
'progressImage',
];
'cancelOptions.cancelAfter',
].map((blacklistItem) => `system.${blacklistItem}`);
systemBlacklist.map((blacklistItem) => `system.${blacklistItem}`);
/**
* Gallery slice persist blacklist
*/
const galleryBlacklist: (keyof GalleryState)[] = [
const galleryBlacklist = [
'categories',
'currentCategory',
'currentImage',
'currentImageUuid',
'shouldAutoSwitchToNewImages',
'intermediateImage',
];
].map((blacklistItem) => `gallery.${blacklistItem}`);
galleryBlacklist.map((blacklistItem) => `gallery.${blacklistItem}`);
/**
* Lightbox slice persist blacklist
*/
const lightboxBlacklist: (keyof LightboxState)[] = ['isLightboxOpen'];
lightboxBlacklist.map((blacklistItem) => `lightbox.${blacklistItem}`);
/**
* Nodes slice persist blacklist
*/
const nodesBlacklist: (keyof NodesState)[] = ['schema', 'invocations'];
nodesBlacklist.map((blacklistItem) => `nodes.${blacklistItem}`);
/**
* Generation slice persist blacklist
*/
const generationBlacklist: (keyof GenerationState)[] = [];
generationBlacklist.map((blacklistItem) => `generation.${blacklistItem}`);
/**
* Postprocessing slice persist blacklist
*/
const postprocessingBlacklist: (keyof PostprocessingState)[] = [];
postprocessingBlacklist.map(
(blacklistItem) => `postprocessing.${blacklistItem}`
const lightboxBlacklist = ['isLightboxOpen'].map(
(blacklistItem) => `lightbox.${blacklistItem}`
);
/**
* Results slice persist blacklist
*
* Currently blacklisting results slice entirely, see persist config below
*/
const resultsBlacklist: (keyof ResultsState)[] = [];
resultsBlacklist.map((blacklistItem) => `results.${blacklistItem}`);
/**
* Uploads slice persist blacklist
*
* Currently blacklisting uploads slice entirely, see persist config below
*/
const uploadsBlacklist: (keyof NodesState)[] = [];
uploadsBlacklist.map((blacklistItem) => `uploads.${blacklistItem}`);
/**
* Models slice persist blacklist
*/
const modelsBlacklist: (keyof NodesState)[] = [];
modelsBlacklist.map((blacklistItem) => `models.${blacklistItem}`);
/**
* UI slice persist blacklist
*/
const uiBlacklist: (keyof NodesState)[] = [];
uiBlacklist.map((blacklistItem) => `ui.${blacklistItem}`);
const rootReducer = combineReducers({
canvas: canvasReducer,
gallery: galleryReducer,
generation: generationReducer,
lightbox: lightboxReducer,
models: modelsReducer,
nodes: nodesReducer,
postprocessing: postprocessingReducer,
results: resultsReducer,
gallery: galleryReducer,
system: systemReducer,
canvas: canvasReducer,
ui: uiReducer,
uploads: uploadsReducer,
lightbox: lightboxReducer,
});
const rootPersistConfig = getPersistConfig({
@@ -176,40 +80,23 @@ const rootPersistConfig = getPersistConfig({
rootReducer,
blacklist: [
...canvasBlacklist,
...galleryBlacklist,
...generationBlacklist,
...lightboxBlacklist,
...modelsBlacklist,
...nodesBlacklist,
...postprocessingBlacklist,
// ...resultsBlacklist,
'results',
...systemBlacklist,
...uiBlacklist,
// ...uploadsBlacklist,
'uploads',
...galleryBlacklist,
...lightboxBlacklist,
],
debounce: 300,
});
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
// TODO: rip the old middleware out when nodes is complete
export function buildMiddleware() {
if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
return socketMiddleware();
} else {
return socketioMiddleware();
}
}
// Continue with store setup
export const store = configureStore({
reducer: persistedReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
immutableCheck: false,
serializableCheck: false,
}).concat(dynamicMiddlewares),
}).concat(socketioMiddleware()),
devTools: {
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
actionsDenylist: [

View File

@@ -1,8 +0,0 @@
import { createAsyncThunk } from '@reduxjs/toolkit';
import { AppDispatch, RootState } from './store';
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
export const createAppAsyncThunk = createAsyncThunk.withTypes<{
state: RootState;
dispatch: AppDispatch;
}>();

View File

@@ -2,6 +2,7 @@ import { Box, useToast } from '@chakra-ui/react';
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import useImageUploader from 'common/hooks/useImageUploader';
import { uploadImage } from 'features/gallery/store/thunks/uploadImage';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { ResourceKey } from 'i18next';
import {
@@ -14,7 +15,6 @@ import {
} from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { imageUploaded } from 'services/thunks/image';
import ImageUploadOverlay from './ImageUploadOverlay';
type ImageUploaderProps = {
@@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
const fileAcceptedCallback = useCallback(
async (file: File) => {
dispatch(imageUploaded({ formData: { file } }));
dispatch(uploadImage({ imageFile: file }));
},
[dispatch]
);
@@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
return;
}
dispatch(imageUploaded({ formData: { file } }));
dispatch(uploadImage({ imageFile: file }));
};
document.addEventListener('paste', pasteImageListener);
return () => {

View File

@@ -1,160 +1,27 @@
// import WorkInProgress from './WorkInProgress';
// import ReactFlow, {
// applyEdgeChanges,
// applyNodeChanges,
// Background,
// Controls,
// Edge,
// Handle,
// Node,
// NodeTypes,
// OnEdgesChange,
// OnNodesChange,
// Position,
// } from 'reactflow';
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import WorkInProgress from './WorkInProgress';
// import 'reactflow/dist/style.css';
// import {
// Fragment,
// FunctionComponent,
// ReactNode,
// useCallback,
// useMemo,
// useState,
// } from 'react';
// import { OpenAPIV3 } from 'openapi-types';
// import { filter, map, reduce } from 'lodash';
// import {
// Box,
// Flex,
// FormControl,
// FormLabel,
// Input,
// Select,
// Switch,
// Text,
// NumberInput,
// NumberInputField,
// NumberInputStepper,
// NumberIncrementStepper,
// NumberDecrementStepper,
// Tooltip,
// chakra,
// Badge,
// Heading,
// VStack,
// HStack,
// Menu,
// MenuButton,
// MenuList,
// MenuItem,
// MenuItemOption,
// MenuGroup,
// MenuOptionGroup,
// MenuDivider,
// IconButton,
// } from '@chakra-ui/react';
// import { FaPlus } from 'react-icons/fa';
// import {
// FIELD_NAMES as FIELD_NAMES,
// FIELDS,
// INVOCATION_NAMES as INVOCATION_NAMES,
// INVOCATIONS,
// } from 'features/nodeEditor/constants';
// console.log('invocations', INVOCATIONS);
// const nodeTypes = reduce(
// INVOCATIONS,
// (acc, val, key) => {
// acc[key] = val.component;
// return acc;
// },
// {} as NodeTypes
// );
// console.log('nodeTypes', nodeTypes);
// // make initial nodes one of every node for now
// let n = 0;
// const initialNodes = map(INVOCATIONS, (i) => ({
// id: i.type,
// type: i.title,
// position: { x: (n += 20), y: (n += 20) },
// data: {},
// }));
// console.log('initialNodes', initialNodes);
// export default function NodesWIP() {
// const [nodes, setNodes] = useState<Node[]>([]);
// const [edges, setEdges] = useState<Edge[]>([]);
// const onNodesChange: OnNodesChange = useCallback(
// (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
// []
// );
// const onEdgesChange: OnEdgesChange = useCallback(
// (changes) => setEdges((eds: Edge[]) => applyEdgeChanges(changes, eds)),
// []
// );
// return (
// <Box
// sx={{
// position: 'relative',
// width: 'full',
// height: 'full',
// borderRadius: 'md',
// }}
// >
// <ReactFlow
// nodeTypes={nodeTypes}
// nodes={nodes}
// edges={edges}
// onNodesChange={onNodesChange}
// onEdgesChange={onEdgesChange}
// >
// <Background />
// <Controls />
// </ReactFlow>
// <HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
// {FIELD_NAMES.map((field) => (
// <Badge
// key={field}
// colorScheme={FIELDS[field].color}
// sx={{ userSelect: 'none' }}
// >
// {field}
// </Badge>
// ))}
// </HStack>
// <Menu>
// <MenuButton
// as={IconButton}
// aria-label="Options"
// icon={<FaPlus />}
// sx={{ position: 'absolute', top: 2, left: 2 }}
// />
// <MenuList>
// {INVOCATION_NAMES.map((name) => {
// const invocation = INVOCATIONS[name];
// return (
// <Tooltip
// key={name}
// label={invocation.description}
// placement="end"
// hasArrow
// >
// <MenuItem>{invocation.title}</MenuItem>
// </Tooltip>
// );
// })}
// </MenuList>
// </Menu>
// </Box>
// );
// }
export default {};
export default function NodesWIP() {
const { t } = useTranslation();
return (
<WorkInProgress>
<Flex
sx={{
flexDirection: 'column',
alignItems: 'center',
justifyContent: 'center',
w: '100%',
h: '100%',
gap: 4,
textAlign: 'center',
}}
>
<Heading>{t('common.nodes')}</Heading>
<VStack maxW="50rem" gap={4}>
<Text>{t('common.nodesDesc')}</Text>
</VStack>
</Flex>
</WorkInProgress>
);
}

View File

@@ -14,8 +14,6 @@ const WorkInProgress = (props: WorkInProgressProps) => {
width: '100%',
height: '100%',
bg: 'base.850',
borderRadius: 'base',
position: 'relative',
}}
>
{children}

View File

@@ -1,72 +0,0 @@
import { RootState } from 'app/store';
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
import { find } from 'lodash';
import {
Graph,
ImageToImageInvocation,
TextToImageInvocation,
} from 'services/api';
import { buildHiResNode, buildImg2ImgNode } from './nodes/image2Image';
import { buildIteration } from './nodes/iteration';
import { buildTxt2ImgNode } from './nodes/text2Image';
function mapTabToFunction(activeTabName: InvokeTabName) {
switch (activeTabName) {
case 'txt2img':
return buildTxt2ImgNode;
case 'img2img':
return buildImg2ImgNode;
default:
return buildTxt2ImgNode;
}
}
const buildBaseNode = (
state: RootState
): Record<string, TextToImageInvocation | ImageToImageInvocation> => {
const { activeTab } = state.ui;
const activeTabName = tabMap[activeTab];
return mapTabToFunction(activeTabName)(state);
};
type BuildGraphOutput = {
graph: Graph;
nodeIdsToSubscribe: string[];
};
export const buildGraph = (state: RootState): BuildGraphOutput => {
const { generation, postprocessing } = state;
const { iterations } = generation;
const { hiresFix, hiresStrength } = postprocessing;
const baseNode = buildBaseNode(state);
let graph: Graph = { nodes: baseNode };
const nodeIdsToSubscribe: string[] = [];
if (iterations > 1) {
graph = buildIteration({ graph, iterations });
}
if (hiresFix) {
const { node, edge } = buildHiResNode(
baseNode as Record<string, TextToImageInvocation>,
hiresStrength
);
graph = {
nodes: {
...graph.nodes,
...node,
},
edges: [...(graph.edges || []), edge],
};
nodeIdsToSubscribe.push(Object.keys(node)[0]);
}
console.log('buildGraph: ', graph);
return { graph, nodeIdsToSubscribe };
};

View File

@@ -1,6 +0,0 @@
import dateFormat from 'dateformat';
/**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');

View File

@@ -1,28 +0,0 @@
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { OpenAPI } from 'services/api';
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
};
export const useGetUrl = () => {
const shouldTransformUrls = useAppSelector(
(state: RootState) => state.system.shouldTransformUrls
);
return {
shouldTransformUrls,
getUrl: (url: string) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
},
};
};

View File

@@ -1,98 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store';
import {
Edge,
ImageToImageInvocation,
TextToImageInvocation,
} from 'services/api';
import { _Image } from 'app/invokeai';
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
export const buildImg2ImgNode = (
state: RootState
): Record<string, ImageToImageInvocation> => {
const nodeId = uuidv4();
const { generation, system, models } = state;
const { shouldDisplayInProgressType } = system;
const { currentModel: model } = models;
const {
prompt,
seed,
steps,
width,
height,
cfgScale,
sampler,
seamless,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
} = generation;
const initialImage = initialImageSelector(state);
if (!initialImage) {
// TODO: handle this
throw 'no initial image';
}
return {
[nodeId]: {
id: nodeId,
type: 'img2img',
prompt,
seed: shouldRandomizeSeed ? -1 : seed,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler: sampler as ImageToImageInvocation['scheduler'],
seamless,
model,
progress_images: shouldDisplayInProgressType === 'full-res',
image: {
image_name: initialImage.name,
image_type: initialImage.type,
},
strength,
fit,
},
};
};
type hiresReturnType = {
node: Record<string, ImageToImageInvocation>;
edge: Edge;
};
export const buildHiResNode = (
baseNode: Record<string, TextToImageInvocation>,
strength?: number
): hiresReturnType => {
const nodeId = uuidv4();
const baseNodeId = Object.keys(baseNode)[0];
const baseNodeValues = Object.values(baseNode)[0];
return {
node: {
[nodeId]: {
...baseNodeValues,
id: nodeId,
type: 'img2img',
strength,
},
},
edge: {
source: {
field: 'image',
node_id: baseNodeId,
},
destination: {
field: 'image',
node_id: nodeId,
},
},
};
};

View File

@@ -1,81 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import {
Edge,
Graph,
ImageToImageInvocation,
IterateInvocation,
RangeInvocation,
TextToImageInvocation,
} from 'services/api';
import { buildImg2ImgNode } from './image2Image';
type BuildIteration = {
graph: Graph;
iterations: number;
};
const buildRangeNode = (
iterations: number
): Record<string, RangeInvocation> => {
const nodeId = uuidv4();
return {
[nodeId]: {
id: nodeId,
type: 'range',
start: 0,
stop: iterations,
step: 1,
},
};
};
const buildIterateNode = (): Record<string, IterateInvocation> => {
const nodeId = uuidv4();
return {
[nodeId]: {
id: nodeId,
type: 'iterate',
collection: [],
index: 0,
},
};
};
export const buildIteration = ({
graph,
iterations,
}: BuildIteration): Graph => {
const rangeNode = buildRangeNode(iterations);
const iterateNode = buildIterateNode();
const baseNode: Graph['nodes'] = graph.nodes;
const edges: Edge[] = [
{
source: {
field: 'collection',
node_id: Object.keys(rangeNode)[0],
},
destination: {
field: 'collection',
node_id: Object.keys(iterateNode)[0],
},
},
{
source: {
field: 'item',
node_id: Object.keys(iterateNode)[0],
},
destination: {
field: 'seed',
node_id: Object.keys(baseNode!)[0],
},
},
];
return {
nodes: {
...rangeNode,
...iterateNode,
...graph.nodes,
},
edges,
};
};

View File

@@ -1,43 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store';
import { TextToImageInvocation } from 'services/api';
export const buildTxt2ImgNode = (
state: RootState
): Record<string, TextToImageInvocation> => {
const nodeId = uuidv4();
const { generation, system, models } = state;
const { shouldDisplayInProgressType } = system;
const { currentModel: model } = models;
const {
prompt,
seed,
steps,
width,
height,
cfgScale: cfg_scale,
sampler,
seamless,
shouldRandomizeSeed,
} = generation;
// missing fields in TextToImageInvocation: strength, hires_fix
return {
[nodeId]: {
id: nodeId,
type: 'txt2img',
prompt,
seed: shouldRandomizeSeed ? -1 : seed,
steps,
width,
height,
cfg_scale,
scheduler: sampler as TextToImageInvocation['scheduler'],
seamless,
model,
progress_images: shouldDisplayInProgressType === 'full-res',
},
};
};

View File

@@ -1,10 +1,8 @@
import React, { lazy, PropsWithChildren, useEffect, useState } from 'react';
import React, { lazy, PropsWithChildren } from 'react';
import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { buildMiddleware, store } from './app/store';
import { store } from './app/store';
import { persistor } from './persistor';
import { OpenAPI } from 'services/api';
import { InvokeTabName } from 'features/ui/store/tabMap';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css';
import '@fontsource/inter/300.css';
@@ -19,61 +17,18 @@ import Loading from './Loading';
// Localization
import './i18n';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
const App = lazy(() => import('./app/App'));
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
interface Props extends PropsWithChildren {
apiUrl?: string;
disabledPanels?: string[];
disabledTabs?: InvokeTabName[];
token?: string;
shouldTransformUrls?: boolean;
}
export default function Component({
apiUrl,
disabledPanels = [],
disabledTabs = [],
token,
children,
shouldTransformUrls,
}: Props) {
useEffect(() => {
// configure API client token
if (token) {
OpenAPI.TOKEN = token;
}
// configure API client base url
if (apiUrl) {
OpenAPI.BASE = apiUrl;
}
// reset dynamically added middlewares
resetMiddlewares();
// TODO: at this point, after resetting the middleware, we really ought to clean up the socket
// stuff by calling `dispatch(socketReset())`. but we cannot dispatch from here as we are
// outside the provider. it's not needed until there is the possibility that we will change
// the `apiUrl`/`token` dynamically.
// rebuild socket middleware with token and apiUrl
addMiddleware(buildMiddleware());
}, [apiUrl, token]);
export default function Component(props: PropsWithChildren) {
return (
<React.StrictMode>
<Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}>
<React.Suspense fallback={<Loading showText />}>
<ThemeLocaleProvider>
<App
options={{ disabledPanels, disabledTabs, shouldTransformUrls }}
>
{children}
</App>
<App>{props.children}</App>
</ThemeLocaleProvider>
</React.Suspense>
</PersistGate>

View File

@@ -5,8 +5,6 @@ import ThemeChanger from './features/system/components/ThemeChanger';
import IAIPopover from './common/components/IAIPopover';
import IAIIconButton from './common/components/IAIIconButton';
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
import StatusIndicator from './features/system/components/StatusIndicator';
import ModelSelect from 'features/system/components/ModelSelect';
export default Component;
export {
@@ -15,6 +13,4 @@ export {
IAIPopover,
IAIIconButton,
SettingsModal,
StatusIndicator,
ModelSelect,
};

View File

@@ -1,7 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { ImageConfig } from 'konva/lib/shapes/Image';
import { isEqual } from 'lodash';
@@ -26,7 +25,7 @@ type Props = Omit<ImageConfig, 'image'>;
const IAICanvasIntermediateImage = (props: Props) => {
const { ...rest } = props;
const intermediateImage = useAppSelector(selector);
const { getUrl } = useGetUrl();
const [loadedImageElement, setLoadedImageElement] =
useState<HTMLImageElement | null>(null);
@@ -37,8 +36,8 @@ const IAICanvasIntermediateImage = (props: Props) => {
tempImage.onload = () => {
setLoadedImageElement(tempImage);
};
tempImage.src = getUrl(intermediateImage.url);
}, [intermediateImage, getUrl]);
tempImage.src = intermediateImage.url;
}, [intermediateImage]);
if (!intermediateImage?.boundingBox) return null;

View File

@@ -1,6 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { isEqual } from 'lodash';
@@ -33,7 +32,6 @@ const selector = createSelector(
const IAICanvasObjectRenderer = () => {
const { objects } = useAppSelector(selector);
const { getUrl } = useGetUrl();
if (!objects) return null;
@@ -42,12 +40,7 @@ const IAICanvasObjectRenderer = () => {
{objects.map((obj, i) => {
if (isCanvasBaseImage(obj)) {
return (
<IAICanvasImage
key={i}
x={obj.x}
y={obj.y}
url={getUrl(obj.image.url)}
/>
<IAICanvasImage key={i} x={obj.x} y={obj.y} url={obj.image.url} />
);
} else if (isCanvasBaseLine(obj)) {
const line = (

View File

@@ -1,6 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { GroupConfig } from 'konva/lib/Group';
import { isEqual } from 'lodash';
@@ -54,16 +53,11 @@ const IAICanvasStagingArea = (props: Props) => {
width,
height,
} = useAppSelector(selector);
const { getUrl } = useGetUrl();
return (
<Group {...rest}>
{shouldShowStagingImage && currentStagingAreaImage && (
<IAICanvasImage
url={getUrl(currentStagingAreaImage.image.url)}
x={x}
y={y}
/>
<IAICanvasImage url={currentStagingAreaImage.image.url} x={x} y={y} />
)}
{shouldShowStagingOutline && (
<Group>

View File

@@ -156,7 +156,7 @@ export const canvasSlice = createSlice({
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
state.cursorPosition = action.payload;
},
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
const image = action.payload;
const { stageDimensions } = state;
@@ -291,7 +291,7 @@ export const canvasSlice = createSlice({
state,
action: PayloadAction<{
boundingBox: IRect;
image: InvokeAI._Image;
image: InvokeAI.Image;
}>
) => {
const { boundingBox, image } = action.payload;

View File

@@ -37,7 +37,7 @@ export type CanvasImage = {
y: number;
width: number;
height: number;
image: InvokeAI._Image;
image: InvokeAI.Image;
};
export type CanvasMaskLine = {
@@ -125,7 +125,7 @@ export interface CanvasState {
cursorPosition: Vector2d | null;
doesCanvasNeedScaling: boolean;
futureLayerStates: CanvasLayerState[];
intermediateImage?: InvokeAI._Image;
intermediateImage?: InvokeAI.Image;
isCanvasInitialized: boolean;
isDrawing: boolean;
isMaskEnabled: boolean;

View File

@@ -105,7 +105,7 @@ export const mergeAndUploadCanvas =
const { url, width, height } = image;
const newImage: InvokeAI._Image = {
const newImage: InvokeAI.Image = {
uuid: uuidv4(),
category: shouldSaveToGallery ? 'result' : 'user',
...image,

View File

@@ -14,9 +14,8 @@ import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
import {
initialImageSelected,
setAllParameters,
// setInitialImage,
setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
@@ -28,6 +27,7 @@ import {
} from 'features/ui/store/uiSelectors';
import {
setActiveTab,
setShouldHidePreview,
setShouldShowImageDetails,
} from 'features/ui/store/uiSlice';
import { useHotkeys } from 'react-hotkeys-hook';
@@ -39,6 +39,8 @@ import {
FaDownload,
FaExpand,
FaExpandArrowsAlt,
FaEye,
FaEyeSlash,
FaGrinStars,
FaQuoteRight,
FaSeedling,
@@ -46,15 +48,11 @@ import {
FaShareAlt,
FaTrash,
} from 'react-icons/fa';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import { gallerySelector } from '../store/gallerySelectors';
import DeleteImageModal from './DeleteImageModal';
import { useCallback } from 'react';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useGetUrl } from 'common/util/getUrl';
const currentImageButtonsSelector = createSelector(
[
@@ -64,7 +62,6 @@ const currentImageButtonsSelector = createSelector(
uiSelector,
lightboxSelector,
activeTabNameSelector,
selectedImageSelector,
],
(
system: SystemState,
@@ -72,8 +69,7 @@ const currentImageButtonsSelector = createSelector(
postprocessing,
ui,
lightbox,
activeTabName,
selectedImage
activeTabName
) => {
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
system;
@@ -82,7 +78,7 @@ const currentImageButtonsSelector = createSelector(
const { isLightboxOpen } = lightbox;
const { shouldShowImageDetails } = ui;
const { shouldShowImageDetails, shouldHidePreview } = ui;
const { intermediateImage, currentImage } = gallery;
@@ -98,7 +94,7 @@ const currentImageButtonsSelector = createSelector(
shouldShowImageDetails,
activeTabName,
isLightboxOpen,
selectedImage,
shouldHidePreview,
};
},
{
@@ -125,32 +121,27 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
facetoolStrength,
shouldDisableToolbarButtons,
shouldShowImageDetails,
// currentImage,
currentImage,
isLightboxOpen,
activeTabName,
selectedImage,
shouldHidePreview,
} = useAppSelector(currentImageButtonsSelector);
const { getUrl, shouldTransformUrls } = useGetUrl();
const toast = useToast();
const { t } = useTranslation();
const setBothPrompts = useSetBothPrompts();
const handleClickUseAsInitialImage = () => {
if (!selectedImage) return;
if (!currentImage) return;
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
dispatch(initialImageSelected(selectedImage.name));
// dispatch(setInitialImage(currentImage));
// dispatch(setActiveTab('img2img'));
dispatch(setInitialImage(currentImage));
dispatch(setActiveTab('img2img'));
};
const handleCopyImage = async () => {
if (!selectedImage) return;
if (!currentImage) return;
const blob = await fetch(getUrl(selectedImage.url)).then((res) =>
res.blob()
);
const blob = await fetch(currentImage.url).then((res) => res.blob());
const data = [new ClipboardItem({ [blob.type]: blob })];
await navigator.clipboard.write(data);
@@ -164,26 +155,24 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
};
const handleCopyImageLink = () => {
const url = selectedImage
? shouldTransformUrls
? getUrl(selectedImage.url)
: window.location.toString() + selectedImage.url
: '';
navigator.clipboard.writeText(url).then(() => {
toast({
title: t('toast.imageLinkCopied'),
status: 'success',
duration: 2500,
isClosable: true,
navigator.clipboard
.writeText(
currentImage ? window.location.toString() + currentImage.url : ''
)
.then(() => {
toast({
title: t('toast.imageLinkCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
});
});
};
useHotkeys(
'shift+i',
() => {
if (selectedImage) {
if (currentImage) {
handleClickUseAsInitialImage();
toast({
title: t('toast.sentToImageToImage'),
@@ -201,27 +190,28 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[selectedImage]
[currentImage]
);
const handlePreviewVisibility = () => {
dispatch(setShouldHidePreview(!shouldHidePreview));
};
const handleClickUseAllParameters = () => {
if (!selectedImage) return;
// selectedImage.metadata &&
// dispatch(setAllParameters(selectedImage.metadata));
// if (selectedImage.metadata?.image.type === 'img2img') {
// dispatch(setActiveTab('img2img'));
// } else if (selectedImage.metadata?.image.type === 'txt2img') {
// dispatch(setActiveTab('txt2img'));
// }
if (!currentImage) return;
currentImage.metadata && dispatch(setAllParameters(currentImage.metadata));
if (currentImage.metadata?.image.type === 'img2img') {
dispatch(setActiveTab('img2img'));
} else if (currentImage.metadata?.image.type === 'txt2img') {
dispatch(setActiveTab('txt2img'));
}
};
useHotkeys(
'a',
() => {
if (
['txt2img', 'img2img'].includes(
selectedImage?.metadata?.sd_metadata?.type
)
['txt2img', 'img2img'].includes(currentImage?.metadata?.image?.type)
) {
handleClickUseAllParameters();
toast({
@@ -240,18 +230,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[selectedImage]
[currentImage]
);
const handleClickUseSeed = () => {
selectedImage?.metadata &&
dispatch(setSeed(selectedImage.metadata.sd_metadata.seed));
currentImage?.metadata &&
dispatch(setSeed(currentImage.metadata.image.seed));
};
useHotkeys(
's',
() => {
if (selectedImage?.metadata?.sd_metadata?.seed) {
if (currentImage?.metadata?.image?.seed) {
handleClickUseSeed();
toast({
title: t('toast.seedSet'),
@@ -269,19 +259,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[selectedImage]
[currentImage]
);
const handleClickUsePrompt = useCallback(() => {
if (selectedImage?.metadata?.sd_metadata?.prompt) {
setBothPrompts(selectedImage?.metadata?.sd_metadata?.prompt);
if (currentImage?.metadata?.image?.prompt) {
setBothPrompts(currentImage?.metadata?.image?.prompt);
}
}, [selectedImage?.metadata?.sd_metadata?.prompt, setBothPrompts]);
}, [currentImage?.metadata?.image?.prompt, setBothPrompts]);
useHotkeys(
'p',
() => {
if (selectedImage?.metadata?.sd_metadata?.prompt) {
if (currentImage?.metadata?.image?.prompt) {
handleClickUsePrompt();
toast({
title: t('toast.promptSet'),
@@ -299,11 +289,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[selectedImage]
[currentImage]
);
const handleClickUpscale = () => {
// selectedImage && dispatch(runESRGAN(selectedImage));
currentImage && dispatch(runESRGAN(currentImage));
};
useHotkeys(
@@ -327,7 +317,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
},
[
selectedImage,
currentImage,
isESRGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@@ -337,7 +327,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
);
const handleClickFixFaces = () => {
// selectedImage && dispatch(runFacetool(selectedImage));
currentImage && dispatch(runFacetool(currentImage));
};
useHotkeys(
@@ -361,7 +351,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
},
[
selectedImage,
currentImage,
isGFPGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@@ -374,10 +364,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
dispatch(setShouldShowImageDetails(!shouldShowImageDetails));
const handleSendToCanvas = () => {
if (!selectedImage) return;
if (!currentImage) return;
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
// dispatch(setInitialCanvasImage(selectedImage));
dispatch(setInitialCanvasImage(currentImage));
dispatch(requestCanvasRescale());
if (activeTabName !== 'unifiedCanvas') {
@@ -395,7 +385,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys(
'i',
() => {
if (selectedImage) {
if (currentImage) {
handleClickShowImageDetails();
} else {
toast({
@@ -406,7 +396,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[selectedImage, shouldShowImageDetails]
[currentImage, shouldShowImageDetails]
);
const handleLightBox = () => {
@@ -467,13 +457,28 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{t('parameters.copyImageToLink')}
</IAIButton>
<Link download={true} href={getUrl(selectedImage!.url)}>
<Link download={true} href={currentImage?.url}>
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
{t('parameters.downloadImage')}
</IAIButton>
</Link>
</Flex>
</IAIPopover>
<IAIIconButton
icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />}
tooltip={
!shouldHidePreview
? t('parameters.hidePreview')
: t('parameters.showPreview')
}
aria-label={
!shouldHidePreview
? t('parameters.hidePreview')
: t('parameters.showPreview')
}
isChecked={shouldHidePreview}
onClick={handlePreviewVisibility}
/>
<IAIIconButton
icon={<FaExpand />}
tooltip={
@@ -496,7 +501,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!selectedImage?.metadata?.sd_metadata?.prompt}
isDisabled={!currentImage?.metadata?.image?.prompt}
onClick={handleClickUsePrompt}
/>
@@ -504,7 +509,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaSeedling />}
tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!selectedImage?.metadata?.sd_metadata?.seed}
isDisabled={!currentImage?.metadata?.image?.seed}
onClick={handleClickUseSeed}
/>
@@ -514,7 +519,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={
!['txt2img', 'img2img'].includes(
selectedImage?.metadata?.sd_metadata?.type
currentImage?.metadata?.image?.type
)
}
onClick={handleClickUseAllParameters}
@@ -540,7 +545,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton
isDisabled={
!isGFPGANAvailable ||
!selectedImage ||
!currentImage ||
!(isConnected && !isProcessing) ||
!facetoolStrength
}
@@ -569,7 +574,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton
isDisabled={
!isESRGANAvailable ||
!selectedImage ||
!currentImage ||
!(isConnected && !isProcessing) ||
!upscalingLevel
}
@@ -591,15 +596,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/>
</ButtonGroup>
{/* <DeleteImageModal image={selectedImage}>
<DeleteImageModal image={currentImage}>
<IAIIconButton
icon={<FaTrash />}
tooltip={`${t('parameters.deleteImage')} (Del)`}
aria-label={`${t('parameters.deleteImage')} (Del)`}
isDisabled={!selectedImage || !isConnected || isProcessing}
isDisabled={!currentImage || !isConnected || isProcessing}
colorScheme="error"
/>
</DeleteImageModal> */}
</DeleteImageModal>
</Flex>
);
};

View File

@@ -4,20 +4,17 @@ import { useAppSelector } from 'app/storeHooks';
import { isEqual } from 'lodash';
import { MdPhoto } from 'react-icons/md';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import { gallerySelector } from '../store/gallerySelectors';
import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview';
export const currentImageDisplaySelector = createSelector(
[gallerySelector, selectedImageSelector],
(gallery, selectedImage) => {
[gallerySelector],
(gallery) => {
const { currentImage, intermediateImage } = gallery;
return {
hasAnImageToDisplay: selectedImage || intermediateImage,
hasAnImageToDisplay: currentImage || intermediateImage,
};
},
{

View File

@@ -0,0 +1,21 @@
import { Flex } from '@chakra-ui/react';
import { FaEyeSlash } from 'react-icons/fa';
const CurrentImageHidden = () => {
return (
<Flex
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'absolute',
color: 'base.400',
}}
>
<FaEyeSlash size={'30vh'} />
</Flex>
);
};
export default CurrentImageHidden;

View File

@@ -1,46 +1,28 @@
import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { systemSelector } from 'features/system/store/systemSelectors';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import { ReactEventHandler } from 'react';
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
import { selectedImageSelector } from '../store/gallerySelectors';
import { gallerySelector } from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
import CurrentImageHidden from './CurrentImageHidden';
export const imagesSelector = createSelector(
[uiSelector, selectedImageSelector, systemSelector],
(ui, selectedImage, system) => {
const { shouldShowImageDetails } = ui;
const { progressImage } = system;
// TODO: Clean this up, this is really gross
const imageToDisplay = progressImage
? {
url: progressImage.dataURL,
width: progressImage.width,
height: progressImage.height,
isProgressImage: true,
image: progressImage,
}
: selectedImage
? {
url: selectedImage.url,
width: selectedImage.metadata.width,
height: selectedImage.metadata.height,
isProgressImage: false,
image: selectedImage,
}
: null;
[gallerySelector, uiSelector],
(gallery: GalleryState, ui) => {
const { currentImage, intermediateImage } = gallery;
const { shouldShowImageDetails, shouldHidePreview } = ui;
return {
imageToDisplay: intermediateImage ? intermediateImage : currentImage,
isIntermediate: Boolean(intermediateImage),
shouldShowImageDetails,
imageToDisplay,
shouldHidePreview,
};
},
{
@@ -51,9 +33,12 @@ export const imagesSelector = createSelector(
);
export default function CurrentImagePreview() {
const { shouldShowImageDetails, imageToDisplay } =
useAppSelector(imagesSelector);
const { getUrl } = useGetUrl();
const {
shouldShowImageDetails,
imageToDisplay,
isIntermediate,
shouldHidePreview,
} = useAppSelector(imagesSelector);
return (
<Flex
@@ -67,15 +52,13 @@ export default function CurrentImagePreview() {
>
{imageToDisplay && (
<Image
src={
imageToDisplay.isProgressImage
? imageToDisplay.url
: getUrl(imageToDisplay.url)
}
src={shouldHidePreview ? undefined : imageToDisplay.url}
width={imageToDisplay.width}
height={imageToDisplay.height}
fallback={
!imageToDisplay.isProgressImage ? (
shouldHidePreview ? (
<CurrentImageHidden />
) : !isIntermediate ? (
<CurrentImageFallback />
) : undefined
}
@@ -85,31 +68,27 @@ export default function CurrentImagePreview() {
maxHeight: '100%',
height: 'auto',
position: 'absolute',
imageRendering: imageToDisplay.isProgressImage
? 'pixelated'
: 'initial',
imageRendering: isIntermediate ? 'pixelated' : 'initial',
borderRadius: 'base',
}}
/>
)}
{!shouldShowImageDetails && <NextPrevImageButtons />}
{shouldShowImageDetails &&
imageToDisplay &&
'metadata' in imageToDisplay.image && (
<Box
sx={{
position: 'absolute',
top: '0',
width: '100%',
height: '100%',
borderRadius: 'base',
overflow: 'scroll',
maxHeight: APP_METADATA_HEIGHT,
}}
>
<ImageMetadataViewer image={imageToDisplay.image} />
</Box>
)}
{shouldShowImageDetails && imageToDisplay && (
<Box
sx={{
position: 'absolute',
top: '0',
width: '100%',
height: '100%',
borderRadius: 'base',
overflow: 'scroll',
maxHeight: APP_METADATA_HEIGHT,
}}
>
<ImageMetadataViewer image={imageToDisplay} />
</Box>
)}
</Flex>
);
}

View File

@@ -52,7 +52,7 @@ interface DeleteImageModalProps {
/**
* The image to delete.
*/
image?: InvokeAI._Image;
image?: InvokeAI.Image;
}
/**

View File

@@ -9,14 +9,11 @@ import {
useToast,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { setCurrentImage } from 'features/gallery/store/gallerySlice';
import {
imageSelected,
setCurrentImage,
} from 'features/gallery/store/gallerySlice';
import {
initialImageSelected,
setAllImageToImageParameters,
setAllParameters,
setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { DragEvent, memo, useState } from 'react';
@@ -34,7 +31,6 @@ import { useTranslation } from 'react-i18next';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import IAIIconButton from 'common/components/IAIIconButton';
import { useGetUrl } from 'common/util/getUrl';
interface HoverableImageProps {
image: InvokeAI.Image;
@@ -44,7 +40,7 @@ interface HoverableImageProps {
const memoEqualityCheck = (
prev: HoverableImageProps,
next: HoverableImageProps
) => prev.image.name === next.image.name && prev.isSelected === next.isSelected;
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
/**
* Gallery image component with delete/use all/use seed buttons on hover.
@@ -59,8 +55,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
shouldUseSingleGalleryColumn,
} = useAppSelector(hoverableImageSelector);
const { image, isSelected } = props;
const { url, thumbnail, name, metadata } = image;
const { getUrl } = useGetUrl();
const { url, thumbnail, uuid, metadata } = image;
const [isHovered, setIsHovered] = useState<boolean>(false);
@@ -74,9 +69,10 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleMouseOut = () => setIsHovered(false);
const handleUsePrompt = () => {
if (image.metadata?.sd_metadata?.prompt) {
setBothPrompts(image.metadata?.sd_metadata?.prompt);
if (image.metadata?.image?.prompt) {
setBothPrompts(image.metadata?.image?.prompt);
}
toast({
title: t('toast.promptSet'),
status: 'success',
@@ -86,8 +82,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseSeed = () => {
image.metadata.sd_metadata &&
dispatch(setSeed(image.metadata.sd_metadata.image.seed));
image.metadata && dispatch(setSeed(image.metadata.image.seed));
toast({
title: t('toast.seedSet'),
status: 'success',
@@ -97,11 +92,20 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleSendToImageToImage = () => {
dispatch(initialImageSelected(image.name));
dispatch(setInitialImage(image));
if (activeTabName !== 'img2img') {
dispatch(setActiveTab('img2img'));
}
toast({
title: t('toast.sentToImageToImage'),
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleSendToCanvas = () => {
// dispatch(setInitialCanvasImage(image));
dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
@@ -118,7 +122,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseAllParameters = () => {
metadata.sd_metadata && dispatch(setAllParameters(metadata.sd_metadata));
metadata && dispatch(setAllParameters(metadata));
toast({
title: t('toast.parametersSet'),
status: 'success',
@@ -128,13 +132,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseInitialImage = async () => {
if (metadata.sd_metadata?.image?.init_image_path) {
const response = await fetch(
metadata.sd_metadata?.image?.init_image_path
);
if (metadata?.image?.init_image_path) {
const response = await fetch(metadata.image.init_image_path);
if (response.ok) {
dispatch(setActiveTab('img2img'));
dispatch(setAllImageToImageParameters(metadata?.sd_metadata));
dispatch(setAllImageToImageParameters(metadata));
toast({
title: t('toast.initialImageSet'),
status: 'success',
@@ -153,18 +155,16 @@ const HoverableImage = memo((props: HoverableImageProps) => {
});
};
const handleSelectImage = () => {
dispatch(imageSelected(image.name));
};
const handleSelectImage = () => dispatch(setCurrentImage(image));
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
// e.dataTransfer.setData('invokeai/imageUuid', uuid);
// e.dataTransfer.effectAllowed = 'move';
e.dataTransfer.setData('invokeai/imageUuid', uuid);
e.dataTransfer.effectAllowed = 'move';
};
const handleLightBox = () => {
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
dispatch(setCurrentImage(image));
dispatch(setIsLightboxOpen(true));
};
return (
@@ -177,30 +177,28 @@ const HoverableImage = memo((props: HoverableImageProps) => {
</MenuItem>
<MenuItem
onClickCapture={handleUsePrompt}
isDisabled={image?.metadata?.sd_metadata?.prompt === undefined}
isDisabled={image?.metadata?.image?.prompt === undefined}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
onClickCapture={handleUseSeed}
isDisabled={image?.metadata?.sd_metadata?.seed === undefined}
isDisabled={image?.metadata?.image?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
onClickCapture={handleUseAllParameters}
isDisabled={
!['txt2img', 'img2img'].includes(
image?.metadata?.sd_metadata?.type
)
!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)
}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
onClickCapture={handleUseInitialImage}
isDisabled={image?.metadata?.sd_metadata?.type !== 'img2img'}
isDisabled={image?.metadata?.image?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem>
@@ -211,9 +209,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
<MenuItem data-warning>
{/* <DeleteImageModal image={image}>
<DeleteImageModal image={image}>
<p>{t('parameters.deleteImage')}</p>
</DeleteImageModal> */}
</DeleteImageModal>
</MenuItem>
</MenuList>
)}
@@ -221,7 +219,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{(ref) => (
<Box
position="relative"
key={name}
key={uuid}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
userSelect="none"
@@ -246,7 +244,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
}
rounded="md"
src={getUrl(thumbnail || url)}
src={thumbnail || url}
loading="lazy"
sx={{
position: 'absolute',
@@ -292,7 +290,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
insetInlineEnd: 1,
}}
>
{/* <DeleteImageModal image={image}>
<DeleteImageModal image={image}>
<IAIIconButton
aria-label={t('parameters.deleteImage')}
icon={<FaTrashAlt />}
@@ -300,7 +298,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
fontSize={14}
isDisabled={!mayDeleteImage}
/>
</DeleteImageModal> */}
</DeleteImageModal>
</Box>
)}
</Box>

View File

@@ -1,4 +1,4 @@
import { ButtonGroup, Flex, Grid, Icon, Image, Text } from '@chakra-ui/react';
import { ButtonGroup, Flex, Grid, Icon, Text } from '@chakra-ui/react';
import { requestImages } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton';
@@ -25,44 +25,9 @@ import HoverableImage from './HoverableImage';
import Scrollable from 'features/ui/components/common/Scrollable';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import {
resultsAdapter,
selectResultsAll,
selectResultsTotal,
} from '../store/resultsSlice';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { selectUploadsAll, uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
const gallerySelector = createSelector(
[
(state: RootState) => state.uploads,
(state: RootState) => state.results,
(state: RootState) => state.gallery,
],
(uploads, results, gallery) => {
const { currentCategory } = gallery;
return currentCategory === 'result'
? {
images: resultsAdapter.getSelectors().selectAll(results),
isLoading: results.isLoading,
areMoreImagesAvailable: results.page < results.pages - 1,
}
: {
images: uploadsAdapter.getSelectors().selectAll(uploads),
isLoading: uploads.isLoading,
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
};
}
);
const ImageGalleryContent = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
@@ -70,7 +35,7 @@ const ImageGalleryContent = () => {
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
const {
// images,
images,
currentCategory,
currentImageUuid,
shouldPinGallery,
@@ -78,24 +43,12 @@ const ImageGalleryContent = () => {
galleryGridTemplateColumns,
galleryImageObjectFit,
shouldAutoSwitchToNewImages,
// areMoreImagesAvailable,
areMoreImagesAvailable,
shouldUseSingleGalleryColumn,
} = useAppSelector(imageGallerySelector);
const { images, areMoreImagesAvailable, isLoading } =
useAppSelector(gallerySelector);
// const handleClickLoadMore = () => {
// dispatch(requestImages(currentCategory));
// };
const handleClickLoadMore = () => {
if (currentCategory === 'result') {
dispatch(receivedResultImagesPage());
}
if (currentCategory === 'user') {
dispatch(receivedUploadImagesPage());
}
dispatch(requestImages(currentCategory));
};
const handleChangeGalleryImageMinimumWidth = (v: number) => {
@@ -250,11 +203,11 @@ const ImageGalleryContent = () => {
style={{ gridTemplateColumns: galleryGridTemplateColumns }}
>
{images.map((image) => {
const { name } = image;
const isSelected = currentImageUuid === name;
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
return (
<HoverableImage
key={name}
key={uuid}
image={image}
isSelected={isSelected}
/>
@@ -264,7 +217,6 @@ const ImageGalleryContent = () => {
<IAIButton
onClick={handleClickLoadMore}
isDisabled={!areMoreImagesAvailable}
isLoading={isLoading}
flexShrink={0}
>
{areMoreImagesAvailable

View File

@@ -11,7 +11,6 @@ import {
} from '@chakra-ui/react';
import * as InvokeAI from 'app/invokeai';
import { useAppDispatch } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import promptToString from 'common/util/promptToString';
import { seedWeightsToString } from 'common/util/seedWeightPairs';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
@@ -19,7 +18,7 @@ import {
setCfgScale,
setHeight,
setImg2imgStrength,
// setInitialImage,
setInitialImage,
setMaskPath,
setPerlin,
setSampler,
@@ -46,7 +45,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import * as png from '@stevebel/png';
type MetadataItemProps = {
isLink?: boolean;
@@ -122,7 +120,7 @@ type ImageMetadataViewerProps = {
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.name === next.image.name;
) => prev.image.uuid === next.image.uuid;
// TODO: Show more interesting information in this component.
@@ -139,8 +137,8 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
dispatch(setShouldShowImageDetails(false));
});
const metadata = image?.metadata.sd_metadata || {};
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
const metadata = image?.metadata?.image || {};
const dreamPrompt = image?.dreamPrompt;
const {
cfg_scale,
@@ -162,23 +160,11 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
type,
variations,
width,
model_weights,
} = metadata;
const { t } = useTranslation();
const { getUrl } = useGetUrl();
const metadataJSON = JSON.stringify(image, null, 2);
// fetch(getUrl(image.url))
// .then((r) => r.arrayBuffer())
// .then((buffer) => {
// const { text } = png.decode(buffer);
// const metadata = text?.['sd-metadata']
// ? JSON.parse(text['sd-metadata'] ?? {})
// : {};
// console.log(metadata);
// });
const metadataJSON = JSON.stringify(image.metadata, null, 2);
return (
<Flex
@@ -197,49 +183,18 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
>
<Flex gap={2}>
<Text fontWeight="semibold">File:</Text>
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
<Link href={image.url} isExternal maxW="calc(100% - 3rem)">
{image.url.length > 64
? image.url.substring(0, 64).concat('...')
: image.url}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
{Object.keys(metadata).length > 0 ? (
<>
{type && <MetadataItem label="Generation type" value={type} />}
{model_weights && (
<MetadataItem label="Model" value={model_weights} />
{image.metadata?.model_weights && (
<MetadataItem label="Model" value={image.metadata.model_weights} />
)}
{['esrgan', 'gfpgan'].includes(type) && (
<MetadataItem label="Original image" value={orig_path} />
@@ -333,14 +288,14 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
onClick={() => dispatch(setHeight(height))}
/>
)}
{/* {init_image_path && (
{init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
)}
{mask_image_path && (
<MetadataItem
label="Mask image"
@@ -453,6 +408,37 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
{dreamPrompt && (
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
)}
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
</>
) : (
<Center width="100%" pt={10}>

View File

@@ -7,16 +7,6 @@ import {
uiSelector,
} from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import {
selectResultsAll,
selectResultsById,
selectResultsEntities,
} from './resultsSlice';
import {
selectUploadsAll,
selectUploadsById,
selectUploadsEntities,
} from './uploadsSlice';
export const gallerySelector = (state: RootState) => state.gallery;
@@ -85,18 +75,3 @@ export const hoverableImageSelector = createSelector(
},
}
);
export const selectedImageSelector = createSelector(
[gallerySelector, selectResultsEntities, selectUploadsEntities],
(gallery, allResults, allUploads) => {
const selectedImageName = gallery.selectedImageName;
if (selectedImageName in allResults) {
return allResults[selectedImageName];
}
if (selectedImageName in allUploads) {
return allUploads[selectedImageName];
}
}
);

View File

@@ -1,17 +1,14 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { invocationComplete } from 'services/events/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { clamp } from 'lodash';
import { isImageOutput } from 'services/types/guards';
import { imageUploaded } from 'services/thunks/image';
export type GalleryCategory = 'user' | 'result';
export type AddImagesPayload = {
images: Array<InvokeAI._Image>;
images: Array<InvokeAI.Image>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};
@@ -19,33 +16,16 @@ export type AddImagesPayload = {
type GalleryImageObjectFitType = 'contain' | 'cover';
export type Gallery = {
images: InvokeAI._Image[];
images: InvokeAI.Image[];
latest_mtime?: number;
earliest_mtime?: number;
areMoreImagesAvailable: boolean;
};
export interface GalleryState {
/**
* The selected image's unique name
* Use `selectedImageSelector` to access the image
*/
selectedImageName: string;
/**
* The currently selected image
* @deprecated See `state.gallery.selectedImageName`
*/
currentImage?: InvokeAI._Image;
/**
* The currently selected image's uuid.
* @deprecated See `state.gallery.selectedImageName`, use `selectedImageSelector` to access the image
*/
currentImage?: InvokeAI.Image;
currentImageUuid: string;
/**
* The current progress image
* @deprecated See `state.system.progressImage`
*/
intermediateImage?: InvokeAI._Image & {
intermediateImage?: InvokeAI.Image & {
boundingBox?: IRect;
generationMode?: InvokeTabName;
};
@@ -62,7 +42,6 @@ export interface GalleryState {
}
const initialState: GalleryState = {
selectedImageName: '',
currentImageUuid: '',
galleryImageMinimumWidth: 64,
galleryImageObjectFit: 'cover',
@@ -90,10 +69,7 @@ export const gallerySlice = createSlice({
name: 'gallery',
initialState,
reducers: {
imageSelected: (state, action: PayloadAction<string>) => {
state.selectedImageName = action.payload;
},
setCurrentImage: (state, action: PayloadAction<InvokeAI._Image>) => {
setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid;
},
@@ -148,7 +124,7 @@ export const gallerySlice = createSlice({
addImage: (
state,
action: PayloadAction<{
image: InvokeAI._Image;
image: InvokeAI.Image;
category: GalleryCategory;
}>
) => {
@@ -174,10 +150,7 @@ export const gallerySlice = createSlice({
setIntermediateImage: (
state,
action: PayloadAction<
InvokeAI._Image & {
boundingBox?: IRect;
generationMode?: InvokeTabName;
}
InvokeAI.Image & { boundingBox?: IRect; generationMode?: InvokeTabName }
>
) => {
state.intermediateImage = action.payload;
@@ -279,31 +252,9 @@ export const gallerySlice = createSlice({
state.shouldUseSingleGalleryColumn = action.payload;
},
},
extraReducers(builder) {
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
if (isImageOutput(data.result)) {
state.selectedImageName = data.result.image.image_name;
state.intermediateImage = undefined;
}
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { location } = action.payload;
const imageName = location.split('/').pop() || '';
state.selectedImageName = imageName;
});
},
});
export const {
imageSelected,
addImage,
clearIntermediateImage,
removeImage,

View File

@@ -1,149 +0,0 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { invocationComplete } from 'services/events/actions';
import { RootState } from 'app/store';
import {
receivedResultImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { isImageOutput } from 'services/types/guards';
import {
buildImageUrls,
deserializeImageField,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { getUrlAlt } from 'common/util/getUrl';
import { ImageMetadata } from 'services/api';
// import { deserializeImageField } from 'services/util/deserializeImageField';
// use `createEntityAdapter` to create a slice for results images
// https://redux-toolkit.js.org/api/createEntityAdapter#overview
// the "Entity" is InvokeAI.ResultImage, while the "entities" are instances of that type
export const resultsAdapter = createEntityAdapter<Image>({
// Provide a callback to get a stable, unique identifier for each entity. This defaults to
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
selectId: (image) => image.name,
// Order all images by their time (in descending order)
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
});
// This type is intersected with the Entity type to create the shape of the state
type AdditionalResultsState = {
// these are a bit misleading; they refer to sessions, not results, but we don't have a route
// to list all images directly at this time...
page: number; // current page we are on
pages: number; // the total number of pages available
isLoading: boolean; // whether we are loading more images or not, mostly a placeholder
nextPage: number; // the next page to request
};
// export type ResultsState = ReturnType<
// typeof resultsAdapter.getInitialState<AdditionalResultsState>
// >;
export const initialResultsState =
resultsAdapter.getInitialState<AdditionalResultsState>({
// provide the additional initial state
page: 0,
pages: 0,
isLoading: false,
nextPage: 0,
});
export type ResultsState = typeof initialResultsState;
const resultsSlice = createSlice({
name: 'results',
initialState: initialResultsState,
reducers: {
// the adapter provides some helper reducers; see the docs for all of them
// can use them as helper functions within a reducer, or use the function itself as a reducer
// here we just use the function itself as the reducer. we'll call this on `invocation_complete`
// to add a single result
resultAdded: resultsAdapter.upsertOne,
},
extraReducers: (builder) => {
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
// because we pass in the fulfilled thunk action creator, everything is typed
/**
* Received Result Images Page - PENDING
*/
builder.addCase(receivedResultImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Result Images Page - FULFILLED
*/
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const resultImages = items.map((image) =>
deserializeImageResponse(image)
);
// use the adapter reducer to append all the results to state
resultsAdapter.addMany(state, resultImages);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
const { result, invocation, graph_execution_state_id, source_id } = data;
if (isImageOutput(result)) {
const name = result.image.image_name;
const type = result.image.image_type;
const { url, thumbnail } = buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width, // TODO: add tese dimensions
height: result.height,
invokeai: {
session: graph_execution_state_id,
source_id,
invocation,
},
},
};
// const resultImage = deserializeImageField(result.image, invocation);
resultsAdapter.addOne(state, image);
}
});
},
});
// Create a set of memoized selectors based on the location of this entity state
// to be used as selectors in a `useAppSelector()` call
export const {
selectAll: selectResultsAll,
selectById: selectResultsById,
selectEntities: selectResultsEntities,
selectIds: selectResultsIds,
selectTotal: selectResultsTotal,
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
export const { resultAdded } = resultsSlice.actions;
export default resultsSlice.reducer;

View File

@@ -0,0 +1,54 @@
import { AnyAction, ThunkAction } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { RootState } from 'app/store';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { setInitialImage } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { v4 as uuidv4 } from 'uuid';
import { addImage } from '../gallerySlice';
type UploadImageConfig = {
imageFile: File;
};
export const uploadImage =
(
config: UploadImageConfig
): ThunkAction<void, RootState, unknown, AnyAction> =>
async (dispatch, getState) => {
const { imageFile } = config;
const state = getState() as RootState;
const activeTabName = activeTabNameSelector(state);
const formData = new FormData();
formData.append('file', imageFile, imageFile.name);
formData.append(
'data',
JSON.stringify({
kind: 'init',
})
);
const response = await fetch(`${window.location.origin}/upload`, {
method: 'POST',
body: formData,
});
const image = (await response.json()) as InvokeAI.ImageUploadResponse;
const newImage: InvokeAI.Image = {
uuid: uuidv4(),
category: 'user',
...image,
};
dispatch(addImage({ image: newImage, category: 'user' }));
if (activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(newImage));
} else if (activeTabName === 'img2img') {
dispatch(setInitialImage(newImage));
}
};

View File

@@ -1,95 +0,0 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { RootState } from 'app/store';
import {
receivedUploadImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { imageUploaded } from 'services/thunks/image';
import { deserializeImageField } from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
export const uploadsAdapter = createEntityAdapter<Image>({
selectId: (image) => image.name,
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
});
type AdditionalUploadsState = {
page: number;
pages: number;
isLoading: boolean;
nextPage: number;
};
export type UploadssState = ReturnType<
typeof uploadsAdapter.getInitialState<AdditionalUploadsState>
>;
const uploadsSlice = createSlice({
name: 'uploads',
initialState: uploadsAdapter.getInitialState<AdditionalUploadsState>({
page: 0,
pages: 0,
nextPage: 0,
isLoading: false,
}),
reducers: {
uploadAdded: uploadsAdapter.addOne,
},
extraReducers: (builder) => {
/**
* Received Upload Images Page - PENDING
*/
builder.addCase(receivedUploadImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Upload Images Page - FULFILLED
*/
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const images = items.map((image) => deserializeImageResponse(image));
uploadsAdapter.addMany(state, images);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { location, response } = action.payload;
const { image_name, image_url, image_type, metadata, thumbnail_url } =
response;
const uploadedImage: Image = {
name: image_name,
url: image_url,
thumbnail: thumbnail_url,
type: 'uploads',
metadata,
};
uploadsAdapter.addOne(state, uploadedImage);
});
},
});
export const {
selectAll: selectUploadsAll,
selectById: selectUploadsById,
selectEntities: selectUploadsEntities,
selectIds: selectUploadsIds,
selectTotal: selectUploadsTotal,
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
export const { uploadAdded } = uploadsSlice.actions;
export default uploadsSlice.reducer;

View File

@@ -1,10 +1,9 @@
import * as React from 'react';
import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
import * as InvokeAI from 'app/invokeai';
import { useGetUrl } from 'common/util/getUrl';
type ReactPanZoomProps = {
image: InvokeAI._Image;
image: InvokeAI.Image;
styleClass?: string;
alt?: string;
ref?: React.Ref<HTMLImageElement>;
@@ -23,7 +22,6 @@ export default function ReactPanZoomImage({
scaleY,
}: ReactPanZoomProps) {
const { centerView } = useTransformContext();
const { getUrl } = useGetUrl();
return (
<TransformComponent
@@ -37,7 +35,7 @@ export default function ReactPanZoomImage({
transform: `rotate(${rotation}deg) scaleX(${scaleX}) scaleY(${scaleY})`,
width: '100%',
}}
src={getUrl(image.url)}
src={image.url}
alt={alt}
ref={ref}
className={styleClass ? styleClass : ''}

View File

@@ -1,47 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import 'reactflow/dist/style.css';
import { useCallback } from 'react';
import {
Tooltip,
Menu,
MenuButton,
MenuList,
MenuItem,
IconButton,
} from '@chakra-ui/react';
import { FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { nodeAdded } from '../store/nodesSlice';
import { map } from 'lodash';
import { RootState } from 'app/store';
export const AddNodeMenu = () => {
const dispatch = useAppDispatch();
const invocations = useAppSelector(
(state: RootState) => state.nodes.invocations
);
const addNode = useCallback(
(nodeType: string) => {
dispatch(nodeAdded({ id: uuidv4(), invocation: invocations[nodeType] }));
},
[dispatch, invocations]
);
return (
<Menu>
<MenuButton as={IconButton} aria-label="Add Node" icon={<FaPlus />} />
<MenuList>
{map(invocations, ({ title, description, type }, key) => {
return (
<Tooltip key={key} label={description} placement="end" hasArrow>
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
</Tooltip>
);
})}
</MenuList>
</Menu>
);
};

Some files were not shown because too many files have changed in this diff Show More