Apply black

This commit is contained in:
Martin Kristiansen
2023-07-27 10:54:01 -04:00
parent 2183dba5c5
commit 218b6d0546
148 changed files with 5486 additions and 6296 deletions

View File

@@ -42,26 +42,24 @@ def simple_graph():
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = TestEventService(),
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
boards = None, # type: ignore
board_images = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
model_manager=None, # type: ignore
events=TestEventService(),
logger=None, # type: ignore
images=None, # type: ignore
latents=None, # type: ignore
boards=None, # type: ignore
board_images=None, # type: ignore
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=sqlite_memory, table_name="graph_executions"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
configuration = None, # type: ignore
processor=DefaultInvocationProcessor(),
configuration=None, # type: ignore
)
def invoke_next(
g: GraphExecutionState, services: InvocationServices
) -> tuple[BaseInvocation, BaseInvocationOutput]:
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
n = g.next()
if n is None:
return (None, None)
@@ -130,9 +128,7 @@ def test_graph_state_expands_iterator(mock_services):
def test_graph_state_collects(mock_services):
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(
PromptCollectionTestInvocation(id="1", collection=list(test_prompts))
)
graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(test_prompts)))
graph.add_node(IterateInvocation(id="2"))
graph.add_node(PromptTestInvocation(id="3"))
graph.add_node(CollectInvocation(id="4"))
@@ -158,16 +154,10 @@ def test_graph_state_prepares_eagerly(mock_services):
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(
PromptCollectionTestInvocation(
id="prompt_collection", collection=list(test_prompts)
)
)
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
graph.add_node(IterateInvocation(id="iterate"))
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
graph.add_edge(
create_edge("prompt_collection", "collection", "iterate", "collection")
)
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
# separated, fully-preparable chain of nodes
@@ -193,21 +183,13 @@ def test_graph_executes_depth_first(mock_services):
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(
PromptCollectionTestInvocation(
id="prompt_collection", collection=list(test_prompts)
)
)
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
graph.add_node(IterateInvocation(id="iterate"))
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
graph.add_node(PromptTestInvocation(id="prompt_successor"))
graph.add_edge(
create_edge("prompt_collection", "collection", "iterate", "collection")
)
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
graph.add_edge(
create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")
)
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
g = GraphExecutionState(graph=graph)
n1 = invoke_next(g, mock_services)