Merge branch 'main' into JPPhoto-symmetry-enhancements

Jonathan
2023-03-16 06:29:47 -05:00
committed by GitHub
11 changed files with 205 additions and 134 deletions

View File

@@ -10,6 +10,7 @@ from pydantic.fields import Field
from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import (
Edge,
EdgeConnection,
Graph,
GraphExecutionState,
@@ -92,7 +93,7 @@ async def get_session(
async def add_node(
session_id: str = Path(description="The id of the session"),
node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
] = Body(description="The node to add"),
) -> str:
"""Adds a node to the graph"""
@@ -125,7 +126,7 @@ async def update_node(
session_id: str = Path(description="The id of the session"),
node_path: str = Path(description="The path to the node in the graph"),
node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
] = Body(description="The new node"),
) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges"""
@@ -186,7 +187,7 @@ async def delete_node(
)
async def add_edge(
session_id: str = Path(description="The id of the session"),
edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"),
edge: Edge = Body(description="The edge to add"),
) -> GraphExecutionState:
"""Adds an edge to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@@ -228,9 +229,9 @@ async def delete_edge(
return Response(status_code=404)
try:
edge = (
EdgeConnection(node_id=from_node_id, field=from_field),
EdgeConnection(node_id=to_node_id, field=to_field),
edge = Edge(
source=EdgeConnection(node_id=from_node_id, field=from_field),
destination=EdgeConnection(node_id=to_node_id, field=to_field)
)
session.delete_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(
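
For reference, the route change above means clients of the sessions API now send a single Edge object ({"source": ..., "destination": ...}) instead of a two-element tuple of connections. A minimal sketch of the new add_edge payload; the route path, host, and port are assumptions, since the router decorators are not part of this diff:

    import requests

    session_id = "some-session-id"  # placeholder
    edge_body = {
        "source": {"node_id": "1", "field": "image"},
        "destination": {"node_id": "2", "field": "image"},
    }
    # POST the Edge as a single JSON object (the route prefix below is assumed).
    requests.post(
        f"http://localhost:9090/api/v1/sessions/{session_id}/edges",
        json=edge_body,
    )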

View File

@@ -19,7 +19,7 @@ from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import EdgeConnection, GraphExecutionState
from .services.graph import Edge, EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
@@ -77,7 +77,7 @@ def get_command_parser() -> argparse.ArgumentParser:
def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation
) -> list[tuple[EdgeConnection, EdgeConnection]]:
) -> list[Edge]:
"""Generates all possible edges between two invocations"""
atype = type(a)
btype = type(b)
@@ -94,9 +94,9 @@ def generate_matching_edges(
matching_fields = matching_fields.difference(invalid_fields)
edges = [
(
EdgeConnection(node_id=a.id, field=field),
EdgeConnection(node_id=b.id, field=field),
Edge(
source=EdgeConnection(node_id=a.id, field=field),
destination=EdgeConnection(node_id=b.id, field=field)
)
for field in matching_fields
]
@@ -111,16 +111,15 @@ class SessionError(Exception):
def invoke_all(context: CliContext):
"""Runs all invocations in the specified session"""
context.invoker.invoke(context.session, invoke_all=True)
while not context.session.is_complete():
while not context.get_session().is_complete():
# Wait some time
session = context.get_session()
time.sleep(0.1)
# Print any errors
if context.session.has_error():
for n in context.session.errors:
print(
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}"
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
)
raise SessionError()
@@ -203,7 +202,7 @@ def invoke_cli():
continue
# Pipe previous command output (if there was a previous command)
edges = []
edges: list[Edge] = list()
if len(history) > 0 or current_id != start_id:
from_id = (
history[0] if current_id == start_id else str(current_id - 1)
@@ -225,19 +224,19 @@ def invoke_cli():
matching_edges = generate_matching_edges(
link_node, command.command
)
matching_destinations = [e[1] for e in matching_edges]
edges = [e for e in edges if e[1] not in matching_destinations]
matching_destinations = [e.destination for e in matching_edges]
edges = [e for e in edges if e.destination not in matching_destinations]
edges.extend(matching_edges)
if "link" in args and args["link"]:
for link in args["link"]:
edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]]
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
edges.append(
(
EdgeConnection(node_id=link[1], field=link[0]),
EdgeConnection(
Edge(
source=EdgeConnection(node_id=link[1], field=link[0]),
destination=EdgeConnection(
node_id=command.command.id, field=link[2]
),
)
)
)
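
The piping and --link handling above now de-duplicates by the edge's destination attribute rather than by tuple index 1. A small self-contained sketch of that step, using the absolute import path shown in the test files and placeholder node ids:

    from invokeai.app.services.graph import Edge, EdgeConnection

    # An edge already piped into node "2" ...
    edges = [
        Edge(source=EdgeConnection(node_id="1", field="image"),
             destination=EdgeConnection(node_id="2", field="image")),
    ]
    # ... and a newly generated edge targeting the same destination.
    matching_edges = [
        Edge(source=EdgeConnection(node_id="3", field="image"),
             destination=EdgeConnection(node_id="2", field="image")),
    ]
    # Drop any edge whose destination is about to be re-linked, then add the new ones.
    matching_destinations = [e.destination for e in matching_edges]
    edges = [e for e in edges if e.destination not in matching_destinations]
    edges.extend(matching_edges)
    assert len(edges) == 1 and edges[0].source.node_id == "3"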

View File

@@ -4,6 +4,8 @@ from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union
import numpy as np
from torch import Tensor
from PIL import Image
from pydantic import Field
from skimage.exposure.histogram_matching import match_histograms
@@ -12,7 +14,9 @@ from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers())
@@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, sample: Any = None, step: int = 0
) -> None:
self, context: InvocationContext, sample: Tensor, step: int
) -> None:
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
float(step) / float(self.steps),
self.steps,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
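
To spell out the progress path introduced above: the latents are decoded to a low-resolution estimate, scaled from latent space (factor 8) back to pixel dimensions, and encoded as a JPEG data URL. A condensed sketch with a placeholder latent tensor (the exact shape expected by the helper is an assumption here):

    import torch
    from invokeai.backend.generator import Generator
    from invokeai.backend.util.util import image_to_dataURL

    latents = torch.zeros(1, 4, 64, 64)  # placeholder latent batch
    preview = Generator.sample_to_lowres_estimated_image(latents)  # PIL image at latent resolution
    width, height = (d * 8 for d in preview.size)  # latent-space size -> pixel size
    payload = {
        "width": width,
        "height": height,
        "dataURL": image_to_dataURL(preview, image_format="JPEG"),
    }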

View File

@@ -1,7 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict
from typing import Any, Dict, TypedDict
ProgressImage = TypedDict(
"ProgressImage", {"dataURL": str, "width": int, "height": int}
)
class EventServiceBase:
session_event: str = "session_event"
@@ -23,8 +26,9 @@ class EventServiceBase:
self,
graph_execution_state_id: str,
invocation_id: str,
progress_image: ProgressImage | None,
step: int,
percent: float,
total_steps: int,
) -> None:
"""Emitted when there is generation progress"""
self.__emit_session_event(
@@ -32,8 +36,9 @@ class EventServiceBase:
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
progress_image=progress_image,
step=step,
percent=percent,
total_steps=total_steps,
),
)
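
A sketch of emitting the enriched event with the new progress_image field; the ids and step counts are placeholders, and the snippet assumes (as in the upstream service) that the base class routes events through an overridable dispatch(event_name, payload) method:

    from invokeai.app.services.events import EventServiceBase

    class PrintingEvents(EventServiceBase):
        # Minimal stand-in that just prints whatever gets dispatched.
        def dispatch(self, event_name, payload):
            print(event_name, payload)

    events = PrintingEvents()
    events.emit_generator_progress(
        graph_execution_state_id="some-session-id",  # placeholder
        invocation_id="1",                           # placeholder
        progress_image={"dataURL": "data:image/jpeg;base64,...", "width": 64, "height": 64},
        step=5,
        percent=5 / 50,
        total_steps=50,
    )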

View File

@@ -44,6 +44,11 @@ class EdgeConnection(BaseModel):
return hash(f"{self.node_id}.{self.field}")
class Edge(BaseModel):
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")
def get_output_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_type())
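
Edge is a plain pydantic model, so the rest of the changes in this file amount to swapping tuple indexing (e[0], e[1]) for attribute access (e.source, e.destination). A quick construction/serialization sketch (pydantic v1 API, which this codebase uses):

    from invokeai.app.services.graph import Edge, EdgeConnection

    e = Edge(
        source=EdgeConnection(node_id="1", field="image"),
        destination=EdgeConnection(node_id="2", field="image"),
    )
    e.dict()
    # {'source': {'node_id': '1', 'field': 'image'},
    #  'destination': {'node_id': '2', 'field': 'image'}}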
@@ -194,7 +199,7 @@ class Graph(BaseModel):
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict
)
edges: list[tuple[EdgeConnection, EdgeConnection]] = Field(
edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph",
default_factory=list,
)
@@ -251,7 +256,7 @@ class Graph(BaseModel):
except NodeNotFoundError:
pass # Ignore, node doesn't exist (should this throw?)
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
def add_edge(self, edge: Edge) -> None:
"""Adds an edge to a graph
:raises InvalidEdgeError: the provided edge is invalid.
@@ -262,7 +267,7 @@ class Graph(BaseModel):
else:
raise InvalidEdgeError()
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
def delete_edge(self, edge: Edge) -> None:
"""Deletes an edge from a graph"""
try:
@@ -280,7 +285,7 @@ class Graph(BaseModel):
# Validate all edges reference nodes in the graph
node_ids = set(
[e[0].node_id for e in self.edges] + [e[1].node_id for e in self.edges]
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
)
if not all((self.has_node(node_id) for node_id in node_ids)):
return False
@@ -294,10 +299,10 @@ class Graph(BaseModel):
if not all(
(
are_connections_compatible(
self.get_node(e[0].node_id),
e[0].field,
self.get_node(e[1].node_id),
e[1].field,
self.get_node(e.source.node_id),
e.source.field,
self.get_node(e.destination.node_id),
e.destination.field,
)
for e in self.edges
)
@@ -328,58 +333,58 @@ class Graph(BaseModel):
return True
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
def _is_edge_valid(self, edge: Edge) -> bool:
"""Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
try:
from_node = self.get_node(edge[0].node_id)
to_node = self.get_node(edge[1].node_id)
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
return False
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge[1].node_id, edge[1].field)
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
return False
# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge[0].node_id, edge[1].node_id)
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
return False
# Validate that the field types are compatible
if not are_connections_compatible(
from_node, edge[0].field, to_node, edge[1].field
from_node, edge.source.field, to_node, edge.destination.field
):
return False
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge[1].field == "collection":
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid(
edge[1].node_id, new_input=edge[0]
edge.destination.node_id, new_input=edge.source
):
return False
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge[0].field == "item":
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid(
edge[0].node_id, new_output=edge[1]
edge.source.node_id, new_output=edge.destination
):
return False
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge[1].field == "item":
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid(
edge[1].node_id, new_input=edge[0]
edge.destination.node_id, new_input=edge.source
):
return False
# Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge[0].field == "collection":
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid(
edge[0].node_id, new_output=edge[1]
edge.source.node_id, new_output=edge.destination
):
return False
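
The cycle check in _is_edge_valid delegates to networkx: the candidate edge is added to a flattened copy of the graph, and the edge is rejected if the result is no longer a DAG. A tiny standalone illustration of that test:

    import networkx as nx

    g = nx.DiGraph()
    g.add_nodes_from(["1", "2"])
    g.add_edge("2", "1")             # an edge already in the graph
    g.add_edge("1", "2")             # the candidate edge under validation
    nx.is_directed_acyclic_graph(g)  # False -> _is_edge_valid returns False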
@@ -438,15 +443,15 @@ class Graph(BaseModel):
# Remove the graph prefix from the node path
new_graph_node_path = (
new_node.id
if "." not in edge[1].node_id
else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}'
if "." not in edge.destination.node_id
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
(
edge[0],
EdgeConnection(
node_id=new_graph_node_path, field=edge[1].field
),
Edge(
source=edge.source,
destination=EdgeConnection(
node_id=new_graph_node_path, field=edge.destination.field
)
)
)
@@ -454,51 +459,51 @@ class Graph(BaseModel):
# Remove the graph prefix from the node path
new_graph_node_path = (
new_node.id
if "." not in edge[0].node_id
else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}'
if "." not in edge.source.node_id
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
(
EdgeConnection(
node_id=new_graph_node_path, field=edge[0].field
Edge(
source=EdgeConnection(
node_id=new_graph_node_path, field=edge.source.field
),
edge[1],
destination=edge.destination
)
)
def _get_input_edges(
self, node_path: str, field: Optional[str] = None
) -> list[tuple[EdgeConnection, EdgeConnection]]:
) -> list[Edge]:
"""Gets all input edges for a node"""
edges = self._get_input_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if field is None or e[2][1].field == field)
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
# Create full node paths for each edge
return [
(
EdgeConnection(
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
field=e[0].field,
),
EdgeConnection(
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
field=e[1].field,
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
)
)
for _, prefix, e in filtered_edges
]
def _get_input_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
) -> list[tuple["Graph", str, Edge]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = list()
# Return any input edges that appear in this graph
edges.extend(
[(self, prefix, e) for e in self.edges if e[1].node_id == node_path]
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
)
node_id = (
@@ -522,37 +527,37 @@ class Graph(BaseModel):
def _get_output_edges(
self, node_path: str, field: str
) -> list[tuple[EdgeConnection, EdgeConnection]]:
) -> list[Edge]:
"""Gets all output edges for a node"""
edges = self._get_output_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if e[2][0].field == field)
filtered_edges = (e for e in edges if e[2].source.field == field)
# Create full node paths for each edge
return [
(
EdgeConnection(
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
field=e[0].field,
),
EdgeConnection(
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
field=e[1].field,
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
)
)
for _, prefix, e in filtered_edges
]
def _get_output_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
) -> list[tuple["Graph", str, Edge]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = list()
# Return any output edges that appear in this graph
edges.extend(
[(self, prefix, e) for e in self.edges if e[0].node_id == node_path]
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
)
node_id = (
@@ -580,8 +585,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = list([e[0] for e in self._get_input_edges(node_path, "collection")])
outputs = list([e[1] for e in self._get_output_edges(node_path, "item")])
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
if new_input is not None:
inputs.append(new_input)
@@ -622,8 +627,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = list([e[0] for e in self._get_input_edges(node_path, "item")])
outputs = list([e[1] for e in self._get_output_edges(node_path, "collection")])
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
if new_input is not None:
inputs.append(new_input)
@@ -684,7 +689,7 @@ class Graph(BaseModel):
# TODO: Cache this?
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.keys()])
g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges]))
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_flat(
@@ -711,7 +716,7 @@ class Graph(BaseModel):
# TODO: figure out if iteration nodes need to be expanded
unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges])
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
g.add_edges_from(
[
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
@@ -768,6 +773,24 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
# Declare all fields as required; necessary for OpenAPI schema generation build.
# Technically only fields without a `default_factory` need to be listed here.
# See: https://github.com/pydantic/pydantic/discussions/4577
class Config:
schema_extra = {
'required': [
'id',
'graph',
'execution_graph',
'executed',
'executed_history',
'results',
'errors',
'prepared_source_mapping',
'source_prepared_mapping',
]
}
def next(self) -> BaseInvocation | None:
"""Gets the next node ready to execute."""
@@ -841,13 +864,13 @@ class GraphExecutionState(BaseModel):
input_collection_prepared_node_id = next(
n[1]
for n in iteration_node_map
if n[0] == input_collection_edge[0].node_id
if n[0] == input_collection_edge.source.node_id
)
input_collection_prepared_node_output = self.results[
input_collection_prepared_node_id
]
input_collection = getattr(
input_collection_prepared_node_output, input_collection_edge[0].field
input_collection_prepared_node_output, input_collection_edge.source.field
)
self_iteration_count = len(input_collection)
@@ -864,11 +887,11 @@ class GraphExecutionState(BaseModel):
new_edges = list()
for edge in input_edges:
for input_node_id in (
n[1] for n in iteration_node_map if n[0] == edge[0].node_id
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
):
new_edge = (
EdgeConnection(node_id=input_node_id, field=edge[0].field),
EdgeConnection(node_id="", field=edge[1].field),
new_edge = Edge(
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
destination=EdgeConnection(node_id="", field=edge.destination.field),
)
new_edges.append(new_edge)
@@ -893,9 +916,9 @@ class GraphExecutionState(BaseModel):
# Add new edges to execution graph
for edge in new_edges:
new_edge = (
edge[0],
EdgeConnection(node_id=new_node.id, field=edge[1].field),
new_edge = Edge(
source=edge.source,
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
)
self.execution_graph.add_edge(new_edge)
@@ -1043,26 +1066,26 @@ class GraphExecutionState(BaseModel):
return self.execution_graph.nodes[next_node]
def _prepare_inputs(self, node: BaseInvocation):
input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id]
input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
if isinstance(node, CollectInvocation):
output_collection = [
getattr(self.results[edge[0].node_id], edge[0].field)
getattr(self.results[edge.source.node_id], edge.source.field)
for edge in input_edges
if edge[1].field == "item"
if edge.destination.field == "item"
]
setattr(node, "collection", output_collection)
else:
for edge in input_edges:
output_value = getattr(self.results[edge[0].node_id], edge[0].field)
setattr(node, edge[1].field, output_value)
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
setattr(node, edge.destination.field, output_value)
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
def _is_edge_valid(self, edge: Edge) -> bool:
if not self._is_edge_valid(edge):
return False
# Invalid if destination has already been prepared or executed
if edge[1].node_id in self.source_prepared_mapping:
if edge.destination.node_id in self.source_prepared_mapping:
return False
# Otherwise, the edge is valid
@@ -1089,17 +1112,17 @@ class GraphExecutionState(BaseModel):
)
self.graph.delete_node(node_path)
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
if not self._is_node_updatable(edge[1].node_id):
def add_edge(self, edge: Edge) -> None:
if not self._is_node_updatable(edge.destination.node_id):
raise NodeAlreadyExecutedError(
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to"
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to"
)
self.graph.add_edge(edge)
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
if not self._is_node_updatable(edge[1].node_id):
def delete_edge(self, edge: Edge) -> None:
if not self._is_node_updatable(edge.destination.node_id):
raise NodeAlreadyExecutedError(
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted"
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
)
self.graph.delete_edge(edge)

View File

@@ -497,7 +497,8 @@ class Generator:
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
return matched_result
def sample_to_lowres_estimated_image(self, samples):
@staticmethod
def sample_to_lowres_estimated_image(samples):
# originally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
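
Making sample_to_lowres_estimated_image a @staticmethod lets callers produce a preview without constructing a Generator, which is exactly how the new dispatch_progress uses it. A brief sketch (the latent shape is a placeholder):

    import torch
    from invokeai.backend.generator import Generator

    latents = torch.zeros(1, 4, 64, 64)  # placeholder latent batch
    preview = Generator.sample_to_lowres_estimated_image(latents)  # no instance needed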

View File

@@ -3,6 +3,9 @@ import math
import multiprocessing as mp
import os
import re
import io
import base64
from collections import abc
from inspect import isfunction
from pathlib import Path
@@ -364,3 +367,16 @@ def url_attachment_name(url: str) -> dict:
def download_with_progress_bar(url: str, dest: Path) -> bool:
result = download_with_resume(url, dest, access_token=None)
return result is not None
def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
"""
Converts an image into a base64 image dataURL.
"""
buffered = io.BytesIO()
image.save(buffered, format=image_format)
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
return image_base64
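
A quick usage sketch of the new helper:

    from PIL import Image
    from invokeai.backend.util.util import image_to_dataURL

    img = Image.new("RGB", (64, 64))
    url = image_to_dataURL(img, image_format="JPEG")
    assert url.startswith("data:image/jpeg;base64,")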

View File

@@ -38,7 +38,7 @@ dependencies = [
"albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==0.1.10",
"compel==1.0.1",
"datasets",
"diffusers[torch]~=0.14",
"dnspython==2.2.1",

View File

@@ -105,17 +105,20 @@
// Start building nodes
var id = 1;
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "sampler": sampler, "steps": steps, "seed": seed};
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": seed};
id++;
var i2iNode = {"id": id.toString(), "type": "img2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": Math.floor(Math.random() * 10000)};
id++;
var upscaleNode = {"id": id.toString(), "type": "show_image" };
id++
nodes = {};
nodes[initialNode.id] = initialNode;
nodes[i2iNode.id] = i2iNode;
nodes[upscaleNode.id] = upscaleNode;
links = [
[{ "node_id": initialNode.id, field: "image" },
{ "node_id": upscaleNode.id, field: "image" }]
{ "source": { "node_id": initialNode.id, field: "image" }, "destination": { "node_id": i2iNode.id, field: "image" }},
{ "source": { "node_id": i2iNode.id, field: "image" }, "destination": { "node_id": upscaleNode.id, field: "image" }}
];
// expandSize = 128;
// for (var i = 0; i < 6; ++i) {

View File

@@ -1,15 +1,18 @@
from invokeai.app.invocations.image import *
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation
import pytest
# Helpers
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
return Edge(
source=EdgeConnection(node_id = from_id, field = from_field),
destination=EdgeConnection(node_id = to_id, field = to_field)
)
# Tests
def test_connections_are_compatible():
@@ -108,7 +111,7 @@ def test_graph_allows_non_conflicting_id_change():
assert g.get_node("3").prompt == "Banana sushi"
assert len(g.edges) == 1
assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges
assert Edge(source=EdgeConnection(node_id = "3", field = "image"), destination=EdgeConnection(node_id = "2", field = "image")) in g.edges
def test_graph_fails_to_update_node_id_if_conflict():
g = Graph()
@@ -490,10 +493,10 @@ def test_graph_can_deserialize():
assert g2.nodes['1'] is not None
assert g2.nodes['2'] is not None
assert len(g2.edges) == 1
assert g2.edges[0][0].node_id == '1'
assert g2.edges[0][0].field == 'image'
assert g2.edges[0][1].node_id == '2'
assert g2.edges[0][1].field == 'image'
assert g2.edges[0].source.node_id == '1'
assert g2.edges[0].source.field == 'image'
assert g2.edges[0].destination.node_id == '2'
assert g2.edges[0].destination.field == 'image'
def test_graph_can_generate_schema():
# Not throwing on this line is sufficient
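
With the updated helper, tests in this module compare whole Edge models instead of tuples, e.g. (node ids are placeholders):

    e = create_edge("1", "image", "2", "image")
    assert e.source == EdgeConnection(node_id="1", field="image")
    assert e.destination == EdgeConnection(node_id="2", field="image")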

View File

@@ -64,10 +64,12 @@ class PromptCollectionTestInvocation(BaseInvocation):
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import EdgeConnection
from invokeai.app.services.graph import Edge, EdgeConnection
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
return Edge(
source=EdgeConnection(node_id = from_id, field = from_field),
destination=EdgeConnection(node_id = to_id, field = to_field))
class TestEvent: