mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
tidy(app): document & clean up batch prep logic
This commit is contained in:
@@ -406,58 +406,143 @@ class IsFullResult(BaseModel):
|
||||
# region Util
|
||||
|
||||
|
||||
def create_graph_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, list[dict]], None, None]:
|
||||
def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, str], None, None]:
|
||||
"""
|
||||
Create all graph permutations from the given batch data and graph. Yields tuples
|
||||
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
|
||||
that was applied to the graph.
|
||||
Given a batch and a maximum number of sessions to create, generate a tuple of session_id, session_json, and
|
||||
field_values_json for each session.
|
||||
|
||||
The batch has a "source" graph and a data property. The data property is a list of lists of BatchDatum objects.
|
||||
Each BatchDatum has a field identifier (e.g. a node id and field name), and a list of values to substitute into
|
||||
the field.
|
||||
|
||||
This structure allows us to create a new graph for every possible permutation of BatchDatum objects:
|
||||
- Each BatchDatum can be "expanded" into a dict of node-field-value tuples - one for each item in the BatchDatum.
|
||||
- Zip each inner list of expanded BatchDatum objects together. Call this a "batch_data_list".
|
||||
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of permutations of BatchDatum
|
||||
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of lists of BatchDatum objects.
|
||||
Each inner list now represents the substitution values for a single permutation (session).
|
||||
- For each permutation, substitute the values into the graph
|
||||
|
||||
This function is optimized for performance, as it is used to generate a large number of sessions at once.
|
||||
|
||||
Args:
|
||||
batch: The batch to generate sessions from
|
||||
maximum: The maximum number of sessions to generate
|
||||
|
||||
Returns:
|
||||
A generator that yields tuples of session_id, session_json, and field_values_json for each session. The
|
||||
generator will stop early if the maximum number of sessions is reached.
|
||||
"""
|
||||
|
||||
# TODO: Should this be a class method on Batch?
|
||||
|
||||
data: list[list[tuple[dict]]] = []
|
||||
batch_data_collection = batch.data if batch.data is not None else []
|
||||
graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
|
||||
session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
|
||||
|
||||
for batch_datum_list in batch_data_collection:
|
||||
# each batch_datum_list needs to be convered to NodeFieldValues and then zipped
|
||||
|
||||
node_field_values_to_zip: list[list[dict]] = []
|
||||
# Expand each BatchDatum into a list of dicts - one for each item in the BatchDatum
|
||||
for batch_datum in batch_datum_list:
|
||||
node_field_values = [
|
||||
# Note: A tuple here is slightly faster than a dict, but we need the object in dict form to be inserted
|
||||
# in the session_queue table anyways. So, overall creating NFVs as dicts is faster.
|
||||
{"node_path": batch_datum.node_path, "field_name": batch_datum.field_name, "value": item}
|
||||
for item in batch_datum.items
|
||||
]
|
||||
node_field_values_to_zip.append(node_field_values)
|
||||
# Zip the dicts together to create a list of dicts for each permutation
|
||||
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
|
||||
|
||||
# create generator to yield session,nfv tuples
|
||||
# We serialize the graph and session once, then mutate the graph dict in place for each session.
|
||||
#
|
||||
# This sounds scary, but it's actually fine.
|
||||
#
|
||||
# The batch prep logic injects field values into the same fields for each generated session.
|
||||
#
|
||||
# For example, after the product operation, we'll end up with a list of node-field-value tuples like this:
|
||||
# [
|
||||
# (
|
||||
# {"node_path": "1", "field_name": "a", "value": 1},
|
||||
# {"node_path": "2", "field_name": "b", "value": 2},
|
||||
# {"node_path": "3", "field_name": "c", "value": 3},
|
||||
# ),
|
||||
# (
|
||||
# {"node_path": "1", "field_name": "a", "value": 4},
|
||||
# {"node_path": "2", "field_name": "b", "value": 5},
|
||||
# {"node_path": "3", "field_name": "c", "value": 6},
|
||||
# )
|
||||
# ]
|
||||
#
|
||||
# Note that each tuple has the same length, and each tuple substitutes values in for exactly the same node fields.
|
||||
# No matter the complexity of the batch, this property holds true.
|
||||
#
|
||||
# This means each permutation's substitution can be done in-place on the same graph dict, because it overwrites the
|
||||
# previous mutation. We only need to serialize the graph once, and then we can mutate it in place for each session.
|
||||
#
|
||||
# Previously, we had created new Graph objects for each session, but this was very slow for large (1k+ session
|
||||
# batches). We then tried dumping the graph to dict and using deep-copy to create a new dict for each session,
|
||||
# but this was also slow.
|
||||
#
|
||||
# Overall, we achieved a 100x speedup by mutating the graph dict in place for each session over creating new Graph
|
||||
# objects for each session.
|
||||
#
|
||||
# We will also mutate the session dict in place, setting a new ID for each session and setting the mutated graph
|
||||
# dict as the session's graph.
|
||||
|
||||
# Dump the batch's graph to a dict once
|
||||
graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
|
||||
|
||||
# We must provide a Graph object when creating the "dummy" session dict, but we don't actually use it. It will be
|
||||
# overwritten for each session by the mutated graph_as_dict.
|
||||
session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
|
||||
|
||||
# Now we can create a generator that yields the session_id, session_json, and field_values_json for each session.
|
||||
count = 0
|
||||
|
||||
# Each batch may have multiple runs, so we need to generate the same number of sessions for each run. The total is
|
||||
# still limited by the maximum number of sessions.
|
||||
for _ in range(batch.runs):
|
||||
for d in product(*data):
|
||||
if count >= maximum:
|
||||
# We've reached the maximum number of sessions we may generate
|
||||
return
|
||||
|
||||
# Flatten the list of lists of dicts into a single list of dicts
|
||||
# TODO(psyche): Is the a more efficient way to do this?
|
||||
flat_node_field_values = list(chain.from_iterable(d))
|
||||
|
||||
# The fields that are injected for each the same for all graphs. Therefore, we can mutate the graph dict
|
||||
# in place and then serialize it to json for each session. It's functionally the same as creating a new
|
||||
# graph dict for each session, but is more efficient.
|
||||
# Need a fresh ID for each session
|
||||
session_id = uuid_string()
|
||||
|
||||
# Mutate the session dict in place
|
||||
session_dict["id"] = session_id
|
||||
|
||||
for item in flat_node_field_values:
|
||||
graph_as_dict["nodes"][item["node_path"]][item["field_name"]] = item["value"]
|
||||
# Substitute the values into the graph
|
||||
for nfv in flat_node_field_values:
|
||||
graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]
|
||||
|
||||
# Mutate the session dict in place
|
||||
session_dict["graph"] = graph_as_dict
|
||||
yield (session_id, json.dumps(session_dict, default=to_jsonable_python), flat_node_field_values)
|
||||
|
||||
# Serialize the session and field values
|
||||
# Note the use of pydantic's to_jsonable_python to handle serialization of any python object, including sets.
|
||||
session_json = json.dumps(session_dict, default=to_jsonable_python)
|
||||
field_values_json = json.dumps(flat_node_field_values, default=to_jsonable_python)
|
||||
|
||||
# Yield the session_id, session_json, and field_values_json
|
||||
yield (session_id, session_json, field_values_json)
|
||||
|
||||
# Increment the count so we know when to stop
|
||||
count += 1
|
||||
|
||||
|
||||
def calc_session_count(batch: Batch) -> int:
|
||||
"""
|
||||
Calculates the number of sessions that would be created by the batch, without incurring
|
||||
the overhead of actually generating them. Adapted from `create_sessions().
|
||||
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
|
||||
creating them, as is done in `create_session_nfv_tuples()`.
|
||||
|
||||
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
|
||||
many were _actually_ created (which may be less due to the maximum number of sessions).
|
||||
"""
|
||||
# TODO: Should this be a class method on Batch?
|
||||
if not batch.data:
|
||||
@@ -473,20 +558,63 @@ def calc_session_count(batch: Batch) -> int:
|
||||
return len(data_product) * batch.runs
|
||||
|
||||
|
||||
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> list[tuple]:
|
||||
values_to_insert: list[tuple] = []
|
||||
ValueToInsertTuple: TypeAlias = tuple[
|
||||
str, # queue_id
|
||||
str, # session (as stringified JSON)
|
||||
str, # session_id
|
||||
str, # batch_id
|
||||
str | None, # field_values (optional, as stringified JSON)
|
||||
int, # priority
|
||||
str | None, # workflow (optional, as stringified JSON)
|
||||
str | None, # origin (optional)
|
||||
str | None, # destination (optional)
|
||||
str | None, # retried_from_item_id (optional, this is always None for new items)
|
||||
]
|
||||
"""A type alias for the tuple of values to insert into the session queue table."""
|
||||
|
||||
|
||||
def prepare_values_to_insert(
|
||||
queue_id: str, batch: Batch, priority: int, max_new_queue_items: int
|
||||
) -> list[ValueToInsertTuple]:
|
||||
"""
|
||||
Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
|
||||
`executemany` statement to insert multiple rows at once.
|
||||
|
||||
Args:
|
||||
queue_id: The ID of the queue to insert the items into
|
||||
batch: The batch to prepare the values for
|
||||
priority: The priority of the queue items
|
||||
max_new_queue_items: The maximum number of queue items to insert
|
||||
|
||||
Returns:
|
||||
A list of tuples to insert into the session queue table. Each tuple contains the following values:
|
||||
- queue_id
|
||||
- session (as stringified JSON)
|
||||
- session_id
|
||||
- batch_id
|
||||
- field_values (optional, as stringified JSON)
|
||||
- priority
|
||||
- workflow (optional, as stringified JSON)
|
||||
- origin (optional)
|
||||
- destination (optional)
|
||||
- retried_from_item_id (optional, this is always None for new items)
|
||||
"""
|
||||
|
||||
# A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple, but
|
||||
# measured a ~5% performance improvement by using a normal tuple instead. For very large batches (10k+ items), the
|
||||
# this difference becomes noticeable.
|
||||
#
|
||||
# So, despite the inferior DX with normal tuples, we use one here for performance reasons.
|
||||
|
||||
values_to_insert: list[ValueToInsertTuple] = []
|
||||
|
||||
# pydantic's to_jsonable_python handles serialization of any python object, including sets, which json.dumps does
|
||||
# not support by default. Apparently there are sets somewhere in the graph.
|
||||
|
||||
# The same workflow is used for all sessions in the batch - serialize it once
|
||||
workflow_json = json.dumps(batch.workflow, default=to_jsonable_python) if batch.workflow else None
|
||||
|
||||
for session_id, session_json, field_values in create_graph_nfv_tuples(batch, max_new_queue_items):
|
||||
# As a perf optimization, we can mutate the session_dict in place. This is safe because we dump it to json
|
||||
# as part of the tuple construction
|
||||
|
||||
field_values_json = json.dumps(field_values, default=to_jsonable_python) if field_values else None
|
||||
|
||||
for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, max_new_queue_items):
|
||||
values_to_insert.append(
|
||||
(
|
||||
queue_id,
|
||||
|
||||
@@ -9,7 +9,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchDatum,
|
||||
NodeFieldValue,
|
||||
calc_session_count,
|
||||
create_graph_nfv_tuples,
|
||||
create_session_nfv_tuples,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
|
||||
@@ -42,7 +42,7 @@ def batch_graph() -> Graph:
|
||||
|
||||
def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):
|
||||
b = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
|
||||
t = list(create_graph_nfv_tuples(batch=b, maximum=1000))
|
||||
t = list(create_session_nfv_tuples(batch=b, maximum=1000))
|
||||
# 2 list[BatchDatum] * length 2 * 2 runs = 8
|
||||
assert len(t) == 8
|
||||
|
||||
@@ -90,28 +90,28 @@ def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph
|
||||
|
||||
def test_create_sessions_from_batch_without_runs(batch_data_collection, batch_graph):
|
||||
b = Batch(graph=batch_graph, data=batch_data_collection)
|
||||
t = list(create_graph_nfv_tuples(batch=b, maximum=1000))
|
||||
t = list(create_session_nfv_tuples(batch=b, maximum=1000))
|
||||
# 2 list[BatchDatum] * length 2 * 1 runs = 8
|
||||
assert len(t) == 4
|
||||
|
||||
|
||||
def test_create_sessions_from_batch_without_batch(batch_graph):
|
||||
b = Batch(graph=batch_graph, runs=2)
|
||||
t = list(create_graph_nfv_tuples(batch=b, maximum=1000))
|
||||
t = list(create_session_nfv_tuples(batch=b, maximum=1000))
|
||||
# 2 runs
|
||||
assert len(t) == 2
|
||||
|
||||
|
||||
def test_create_sessions_from_batch_without_batch_or_runs(batch_graph):
|
||||
b = Batch(graph=batch_graph)
|
||||
t = list(create_graph_nfv_tuples(batch=b, maximum=1000))
|
||||
t = list(create_session_nfv_tuples(batch=b, maximum=1000))
|
||||
# 1 run
|
||||
assert len(t) == 1
|
||||
|
||||
|
||||
def test_create_sessions_from_batch_with_runs_and_max(batch_data_collection, batch_graph):
|
||||
b = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
|
||||
t = list(create_graph_nfv_tuples(batch=b, maximum=5))
|
||||
t = list(create_session_nfv_tuples(batch=b, maximum=5))
|
||||
# 2 list[BatchDatum] * length 2 * 2 runs = 8, but max is 5
|
||||
assert len(t) == 5
|
||||
|
||||
|
||||
Reference in New Issue
Block a user