Compare commits

...

4 Commits

Author  SHA1  Message  Date

psychedelicious  ae8459f221  tidy(app): document & clean up batch prep logic  2025-02-26 19:21:55 +10:00

psychedelicious  b73f15d4a2  tidy(app): remove timing debug logs  2025-02-26 17:42:25 +10:00

psychedelicious  5e1974d924  perf(app): optimise batch prep logic even more  2025-02-26 17:42:25 +10:00

    Found another place where we deepcopy a dict, but it is safe to mutate.

    Restructured the prep logic a bit to support this. Updated tests to use the new structure.

psychedelicious  3011dfca16  perf(app): optimise batch prep logic  2025-02-26 17:42:25 +10:00

    - Avoid pydantic models when dict manipulation works
    - Avoid extraneous deep copies when we can safely mutate
    - Avoid NamedTuple construct and its overhead
    - Fix tests to use altered function signatures
    - Remove extraneous populate_graph function
3 changed files with 229 additions and 113 deletions
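The perf commits above boil down to one observation: when every permutation overwrites exactly the same fields, a single template dict can be mutated in place and re-serialized, instead of deep-copied per session. A minimal standalone sketch of that trade-off (the toy graph shape and timing harness are illustrative only, not InvokeAI's actual models):

    import copy
    import json
    import time

    # Toy stand-in for a serialized workflow graph
    graph = {"nodes": {"1": {"prompt": ""}, "2": {"prompt": ""}}}
    values = [(f"prompt {i}", f"prompt {i + 1}") for i in range(5000)]

    def with_deepcopy() -> list[str]:
        out = []
        for a, b in values:
            g = copy.deepcopy(graph)  # fresh dict per permutation
            g["nodes"]["1"]["prompt"] = a
            g["nodes"]["2"]["prompt"] = b
            out.append(json.dumps(g))
        return out

    def with_mutation() -> list[str]:
        out = []
        for a, b in values:
            # Safe because every iteration overwrites the same two fields and the
            # dict is serialized before the next mutation touches it
            graph["nodes"]["1"]["prompt"] = a
            graph["nodes"]["2"]["prompt"] = b
            out.append(json.dumps(graph))
        return out

    for fn in (with_deepcopy, with_mutation):
        start = time.perf_counter()
        fn()
        print(f"{fn.__name__}: {time.perf_counter() - start:.3f}s")

Both functions produce identical JSON; only the per-iteration allocation differs, which is where the reported speedup comes from.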

View File

@@ -1,7 +1,7 @@
 import datetime
 import json
 from itertools import chain, product
-from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
+from typing import Generator, Literal, Optional, TypeAlias, Union, cast

 from pydantic import (
     AliasChoices,
@@ -406,61 +406,143 @@ class IsFullResult(BaseModel):

 # region Util
-def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
-    """
-    Populates the given graph with the given batch data items.
-    """
-    graph_clone = graph.model_copy(deep=True)
-    for item in node_field_values:
-        node = graph_clone.get_node(item.node_path)
-        if node is None:
-            continue
-        setattr(node, item.field_name, item.value)
-        graph_clone.update_node(item.node_path, node)
-    return graph_clone
-
-
-def create_session_nfv_tuples(
-    batch: Batch, maximum: int
-) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
-    """
-    Create all graph permutations from the given batch data and graph. Yields tuples
-    of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
-    that was applied to the graph.
+def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, str], None, None]:
+    """
+    Given a batch and a maximum number of sessions to create, generate a tuple of session_id, session_json, and
+    field_values_json for each session.
+
+    The batch has a "source" graph and a data property. The data property is a list of lists of BatchDatum objects.
+    Each BatchDatum has a field identifier (e.g. a node id and field name), and a list of values to substitute into
+    the field.
+
+    This structure allows us to create a new graph for every possible permutation of BatchDatum objects:
+    - Each BatchDatum can be "expanded" into a dict of node-field-value tuples - one for each item in the BatchDatum.
+    - Zip each inner list of expanded BatchDatum objects together. Call this a "batch_data_list".
+    - Take the cartesian product of all zipped batch_data_lists, resulting in a list of lists of BatchDatum objects.
+      Each inner list now represents the substitution values for a single permutation (session).
+    - For each permutation, substitute the values into the graph
+
+    This function is optimized for performance, as it is used to generate a large number of sessions at once.
+
+    Args:
+        batch: The batch to generate sessions from
+        maximum: The maximum number of sessions to generate
+
+    Returns:
+        A generator that yields tuples of session_id, session_json, and field_values_json for each session. The
+        generator will stop early if the maximum number of sessions is reached.
     """
     # TODO: Should this be a class method on Batch?
-    data: list[list[tuple[NodeFieldValue]]] = []
+    data: list[list[tuple[dict]]] = []
     batch_data_collection = batch.data if batch.data is not None else []
-    for batch_datum_list in batch_data_collection:
-        # each batch_datum_list needs to be converted to NodeFieldValues and then zipped
-        node_field_values_to_zip: list[list[NodeFieldValue]] = []
+
+    for batch_datum_list in batch_data_collection:
+        node_field_values_to_zip: list[list[dict]] = []
+        # Expand each BatchDatum into a list of dicts - one for each item in the BatchDatum
         for batch_datum in batch_datum_list:
             node_field_values = [
-                NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
+                # Note: A tuple here is slightly faster than a dict, but we need the object in dict form to be
+                # inserted in the session_queue table anyways. So, overall creating NFVs as dicts is faster.
+                {"node_path": batch_datum.node_path, "field_name": batch_datum.field_name, "value": item}
                 for item in batch_datum.items
             ]
             node_field_values_to_zip.append(node_field_values)
+        # Zip the dicts together to create a list of dicts for each permutation
         data.append(list(zip(*node_field_values_to_zip, strict=True)))  # type: ignore [arg-type]

-    # create generator to yield session,nfv tuples
+    # We serialize the graph and session once, then mutate the graph dict in place for each session.
+    #
+    # This sounds scary, but it's actually fine.
+    #
+    # The batch prep logic injects field values into the same fields for each generated session.
+    #
+    # For example, after the product operation, we'll end up with a list of node-field-value tuples like this:
+    # [
+    #   (
+    #     {"node_path": "1", "field_name": "a", "value": 1},
+    #     {"node_path": "2", "field_name": "b", "value": 2},
+    #     {"node_path": "3", "field_name": "c", "value": 3},
+    #   ),
+    #   (
+    #     {"node_path": "1", "field_name": "a", "value": 4},
+    #     {"node_path": "2", "field_name": "b", "value": 5},
+    #     {"node_path": "3", "field_name": "c", "value": 6},
+    #   )
+    # ]
+    #
+    # Note that each tuple has the same length, and each tuple substitutes values in for exactly the same node
+    # fields. No matter the complexity of the batch, this property holds true.
+    #
+    # This means each permutation's substitution can be done in-place on the same graph dict, because it overwrites
+    # the previous mutation. We only need to serialize the graph once, and then we can mutate it in place for each
+    # session.
+    #
+    # Previously, we had created new Graph objects for each session, but this was very slow for large (1k+ session)
+    # batches. We then tried dumping the graph to a dict and using deep-copy to create a new dict for each session,
+    # but this was also slow.
+    #
+    # Overall, we achieved a 100x speedup by mutating the graph dict in place for each session over creating new
+    # Graph objects for each session.
+    #
+    # We will also mutate the session dict in place, setting a new ID for each session and setting the mutated
+    # graph dict as the session's graph.
+
+    # Dump the batch's graph to a dict once
+    graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
+
+    # We must provide a Graph object when creating the "dummy" session dict, but we don't actually use it. It will
+    # be overwritten for each session by the mutated graph_as_dict.
+    session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
+
+    # Now we can create a generator that yields the session_id, session_json, and field_values_json for each session.
     count = 0
+
+    # Each batch may have multiple runs, so we need to generate the same number of sessions for each run. The total
+    # is still limited by the maximum number of sessions.
     for _ in range(batch.runs):
         for d in product(*data):
             if count >= maximum:
+                # We've reached the maximum number of sessions we may generate
                 return
+
+            # Flatten the list of lists of dicts into a single list of dicts
+            # TODO(psyche): Is there a more efficient way to do this?
             flat_node_field_values = list(chain.from_iterable(d))
-            graph = populate_graph(batch.graph, flat_node_field_values)
-            yield (GraphExecutionState(graph=graph), flat_node_field_values, batch.workflow)
+
+            # Need a fresh ID for each session
+            session_id = uuid_string()
+
+            # Mutate the session dict in place
+            session_dict["id"] = session_id
+
+            # Substitute the values into the graph
+            for nfv in flat_node_field_values:
+                graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]
+
+            # Mutate the session dict in place
+            session_dict["graph"] = graph_as_dict
+
+            # Serialize the session and field values
+            # Note the use of pydantic's to_jsonable_python to handle serialization of any python object, including
+            # sets.
+            session_json = json.dumps(session_dict, default=to_jsonable_python)
+            field_values_json = json.dumps(flat_node_field_values, default=to_jsonable_python)
+
+            # Yield the session_id, session_json, and field_values_json
+            yield (session_id, session_json, field_values_json)
+
+            # Increment the count so we know when to stop
             count += 1


 def calc_session_count(batch: Batch) -> int:
     """
-    Calculates the number of sessions that would be created by the batch, without incurring
-    the overhead of actually generating them. Adapted from `create_sessions()`.
+    Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
+    creating them, as is done in `create_session_nfv_tuples()`.
+
+    The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
+    many were _actually_ created (which may be less due to the maximum number of sessions).
     """
     # TODO: Should this be a class method on Batch?
     if not batch.data:
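The expand/zip/product scheme in the docstring above is easier to see with plain data. A toy reproduction of just that step, with BatchDatum reduced to a bare (node_path, field_name, items) tuple rather than the real pydantic model:

    from itertools import chain, product

    # Stand-in for batch.data: a list of lists of batch datums
    batch_data = [
        # These two datums are zipped: their values advance in lockstep
        [
            ("1", "prompt", ["Banana sushi", "Grape sushi"]),
            ("2", "prompt", ["Strawberry sushi", "Blueberry sushi"]),
        ],
        # This datum sits in its own list, so it is crossed with the pair above
        [("3", "prompt", ["Orange sushi", "Apple sushi"])],
    ]

    data = []
    for batch_datum_list in batch_data:
        expanded = [
            [{"node_path": n, "field_name": f, "value": v} for v in items]
            for n, f, items in batch_datum_list
        ]
        # One tuple of dicts per index; all datums in this list advance together
        data.append(list(zip(*expanded, strict=True)))

    # The cartesian product yields one permutation per combination of zipped tuples
    for d in product(*data):
        print([(nfv["node_path"], nfv["value"]) for nfv in chain.from_iterable(d)])

This prints four permutations (2 zipped pairs x 2 crossed values), which matches the per-run counts asserted in the tests further down.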
@@ -476,42 +558,75 @@ def calc_session_count(batch: Batch) -> int:
     return len(data_product) * batch.runs


-class SessionQueueValueToInsert(NamedTuple):
-    """A tuple of values to insert into the session_queue table"""
-
-    # Careful with the ordering of this - it must match the insert statement
-    queue_id: str  # queue_id
-    session: str  # session json
-    session_id: str  # session_id
-    batch_id: str  # batch_id
-    field_values: Optional[str]  # field_values json
-    priority: int  # priority
-    workflow: Optional[str]  # workflow json
-    origin: str | None
-    destination: str | None
-    retried_from_item_id: int | None = None
-
-
-ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
+ValueToInsertTuple: TypeAlias = tuple[
+    str,  # queue_id
+    str,  # session (as stringified JSON)
+    str,  # session_id
+    str,  # batch_id
+    str | None,  # field_values (optional, as stringified JSON)
+    int,  # priority
+    str | None,  # workflow (optional, as stringified JSON)
+    str | None,  # origin (optional)
+    str | None,  # destination (optional)
+    str | None,  # retried_from_item_id (optional, this is always None for new items)
+]
+"""A type alias for the tuple of values to insert into the session queue table."""


-def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
-    values_to_insert: ValuesToInsert = []
-    for session, field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
-        # sessions must have unique id
-        session.id = uuid_string()
+def prepare_values_to_insert(
+    queue_id: str, batch: Batch, priority: int, max_new_queue_items: int
+) -> list[ValueToInsertTuple]:
+    """
+    Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
+    `executemany` statement to insert multiple rows at once.
+
+    Args:
+        queue_id: The ID of the queue to insert the items into
+        batch: The batch to prepare the values for
+        priority: The priority of the queue items
+        max_new_queue_items: The maximum number of queue items to insert
+
+    Returns:
+        A list of tuples to insert into the session queue table. Each tuple contains the following values:
+        - queue_id
+        - session (as stringified JSON)
+        - session_id
+        - batch_id
+        - field_values (optional, as stringified JSON)
+        - priority
+        - workflow (optional, as stringified JSON)
+        - origin (optional)
+        - destination (optional)
+        - retried_from_item_id (optional, this is always None for new items)
+    """
+    # A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple,
+    # but measured a ~5% performance improvement by using a normal tuple instead. For very large batches (10k+
+    # items), this difference becomes noticeable.
+    #
+    # So, despite the inferior DX with normal tuples, we use one here for performance reasons.
+    values_to_insert: list[ValueToInsertTuple] = []
+
+    # pydantic's to_jsonable_python handles serialization of any python object, including sets, which json.dumps
+    # does not support by default. Apparently there are sets somewhere in the graph.
+
+    # The same workflow is used for all sessions in the batch - serialize it once
+    workflow_json = json.dumps(batch.workflow, default=to_jsonable_python) if batch.workflow else None
+
+    for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, max_new_queue_items):
         values_to_insert.append(
-            SessionQueueValueToInsert(
-                queue_id=queue_id,
-                session=session.model_dump_json(warnings=False, exclude_none=True),  # as json
-                session_id=session.id,
-                batch_id=batch.batch_id,
-                # must use pydantic_encoder bc field_values is a list of models
-                field_values=json.dumps(field_values, default=to_jsonable_python) if field_values else None,  # as json
-                priority=priority,
-                workflow=json.dumps(workflow, default=to_jsonable_python) if workflow else None,  # as json
-                origin=batch.origin,
-                destination=batch.destination,
+            (
+                queue_id,
+                session_json,
+                session_id,
+                batch.batch_id,
+                field_values_json,
+                priority,
+                workflow_json,
+                batch.origin,
+                batch.destination,
+                None,
             )
         )
     return values_to_insert
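The `default=to_jsonable_python` arguments above are what let plain `json.dumps` handle payloads that contain sets (or other non-JSON types) somewhere in the graph. A small demonstration of the mechanism, assuming the `pydantic_core` import that this module presumably uses (the import line is not visible in this diff):

    import json

    from pydantic_core import to_jsonable_python

    payload = {"id": "abc", "tags": {"a", "b"}}  # a set is not JSON-serializable

    try:
        json.dumps(payload)
    except TypeError as err:
        print(f"plain json.dumps fails: {err}")

    # json.dumps calls the default hook only for objects it cannot serialize;
    # to_jsonable_python converts the set to a list
    print(json.dumps(payload, default=to_jsonable_python))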

View File

@@ -27,7 +27,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
     SessionQueueItemDTO,
     SessionQueueItemNotFoundError,
     SessionQueueStatus,
-    SessionQueueValueToInsert,
     calc_session_count,
     prepare_values_to_insert,
 )

@@ -772,7 +771,7 @@ class SqliteSessionQueue(SessionQueueBase):
         try:
             self.__lock.acquire()

-            values_to_insert: list[SessionQueueValueToInsert] = []
+            values_to_insert: list[tuple] = []
             retried_item_ids: list[int] = []

             for item_id in item_ids:

@@ -798,17 +797,17 @@ class SqliteSessionQueue(SessionQueueBase):
                     else queue_item.item_id
                 )

-                value_to_insert = SessionQueueValueToInsert(
-                    queue_id=queue_item.queue_id,
-                    batch_id=queue_item.batch_id,
-                    destination=queue_item.destination,
-                    field_values=field_values_json,
-                    origin=queue_item.origin,
-                    priority=queue_item.priority,
-                    workflow=workflow_json,
-                    session=cloned_session_json,
-                    session_id=cloned_session.id,
-                    retried_from_item_id=retried_from_item_id,
+                value_to_insert = (
+                    queue_item.queue_id,
+                    queue_item.batch_id,
+                    queue_item.destination,
+                    field_values_json,
+                    queue_item.origin,
+                    queue_item.priority,
+                    workflow_json,
+                    cloned_session_json,
+                    cloned_session.id,
+                    retried_from_item_id,
                 )
                 values_to_insert.append(value_to_insert)
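Dropping the NamedTuple works because sqlite3's `executemany` binds parameters positionally; it only needs sequences in the right column order. A minimal sketch of the insert pattern (the schema here is invented for illustration and is far smaller than the real session_queue table):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE session_queue (queue_id TEXT, session TEXT, session_id TEXT, priority INTEGER)"
    )

    # Plain tuples: position must match the column order of the INSERT statement
    values_to_insert = [
        ("default", '{"id": "abc"}', "abc", 0),
        ("default", '{"id": "def"}', "def", 0),
    ]

    # One statement for all rows - this is the shape prepare_values_to_insert feeds
    conn.executemany(
        "INSERT INTO session_queue (queue_id, session, session_id, priority) VALUES (?, ?, ?, ?)",
        values_to_insert,
    )
    conn.commit()
    print(conn.execute("SELECT session_id FROM session_queue").fetchall())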

View File

@@ -1,3 +1,5 @@
+import json
+
 import pytest
 from pydantic import TypeAdapter, ValidationError
@@ -44,46 +46,46 @@ def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):
     # 2 list[BatchDatum] * length 2 * 2 runs = 8
     assert len(t) == 8

-    assert t[0][0].graph.get_node("1").prompt == "Banana sushi"
-    assert t[0][0].graph.get_node("2").prompt == "Strawberry sushi"
-    assert t[0][0].graph.get_node("3").prompt == "Orange sushi"
-    assert t[0][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[0][1])["graph"]["nodes"]["1"]["prompt"] == "Banana sushi"
+    assert json.loads(t[0][1])["graph"]["nodes"]["2"]["prompt"] == "Strawberry sushi"
+    assert json.loads(t[0][1])["graph"]["nodes"]["3"]["prompt"] == "Orange sushi"
+    assert json.loads(t[0][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"

-    assert t[1][0].graph.get_node("1").prompt == "Banana sushi"
-    assert t[1][0].graph.get_node("2").prompt == "Strawberry sushi"
-    assert t[1][0].graph.get_node("3").prompt == "Apple sushi"
-    assert t[1][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[1][1])["graph"]["nodes"]["1"]["prompt"] == "Banana sushi"
+    assert json.loads(t[1][1])["graph"]["nodes"]["2"]["prompt"] == "Strawberry sushi"
+    assert json.loads(t[1][1])["graph"]["nodes"]["3"]["prompt"] == "Apple sushi"
+    assert json.loads(t[1][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"

-    assert t[2][0].graph.get_node("1").prompt == "Grape sushi"
-    assert t[2][0].graph.get_node("2").prompt == "Blueberry sushi"
-    assert t[2][0].graph.get_node("3").prompt == "Orange sushi"
-    assert t[2][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[2][1])["graph"]["nodes"]["1"]["prompt"] == "Grape sushi"
+    assert json.loads(t[2][1])["graph"]["nodes"]["2"]["prompt"] == "Blueberry sushi"
+    assert json.loads(t[2][1])["graph"]["nodes"]["3"]["prompt"] == "Orange sushi"
+    assert json.loads(t[2][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"

-    assert t[3][0].graph.get_node("1").prompt == "Grape sushi"
-    assert t[3][0].graph.get_node("2").prompt == "Blueberry sushi"
-    assert t[3][0].graph.get_node("3").prompt == "Apple sushi"
-    assert t[3][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[3][1])["graph"]["nodes"]["1"]["prompt"] == "Grape sushi"
+    assert json.loads(t[3][1])["graph"]["nodes"]["2"]["prompt"] == "Blueberry sushi"
+    assert json.loads(t[3][1])["graph"]["nodes"]["3"]["prompt"] == "Apple sushi"
+    assert json.loads(t[3][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"

     # repeat for second run
-    assert t[4][0].graph.get_node("1").prompt == "Banana sushi"
-    assert t[4][0].graph.get_node("2").prompt == "Strawberry sushi"
-    assert t[4][0].graph.get_node("3").prompt == "Orange sushi"
-    assert t[4][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[4][1])["graph"]["nodes"]["1"]["prompt"] == "Banana sushi"
+    assert json.loads(t[4][1])["graph"]["nodes"]["2"]["prompt"] == "Strawberry sushi"
+    assert json.loads(t[4][1])["graph"]["nodes"]["3"]["prompt"] == "Orange sushi"
+    assert json.loads(t[4][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"

-    assert t[5][0].graph.get_node("1").prompt == "Banana sushi"
-    assert t[5][0].graph.get_node("2").prompt == "Strawberry sushi"
-    assert t[5][0].graph.get_node("3").prompt == "Apple sushi"
-    assert t[5][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[5][1])["graph"]["nodes"]["1"]["prompt"] == "Banana sushi"
+    assert json.loads(t[5][1])["graph"]["nodes"]["2"]["prompt"] == "Strawberry sushi"
+    assert json.loads(t[5][1])["graph"]["nodes"]["3"]["prompt"] == "Apple sushi"
+    assert json.loads(t[5][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"

-    assert t[6][0].graph.get_node("1").prompt == "Grape sushi"
-    assert t[6][0].graph.get_node("2").prompt == "Blueberry sushi"
-    assert t[6][0].graph.get_node("3").prompt == "Orange sushi"
-    assert t[6][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[6][1])["graph"]["nodes"]["1"]["prompt"] == "Grape sushi"
+    assert json.loads(t[6][1])["graph"]["nodes"]["2"]["prompt"] == "Blueberry sushi"
+    assert json.loads(t[6][1])["graph"]["nodes"]["3"]["prompt"] == "Orange sushi"
+    assert json.loads(t[6][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"

-    assert t[7][0].graph.get_node("1").prompt == "Grape sushi"
-    assert t[7][0].graph.get_node("2").prompt == "Blueberry sushi"
-    assert t[7][0].graph.get_node("3").prompt == "Apple sushi"
-    assert t[7][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[7][1])["graph"]["nodes"]["1"]["prompt"] == "Grape sushi"
+    assert json.loads(t[7][1])["graph"]["nodes"]["2"]["prompt"] == "Blueberry sushi"
+    assert json.loads(t[7][1])["graph"]["nodes"]["3"]["prompt"] == "Apple sushi"
+    assert json.loads(t[7][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"


 def test_create_sessions_from_batch_without_runs(batch_data_collection, batch_graph):
@@ -127,7 +129,7 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
     GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
     # graph should be serialized
-    ges = GraphExecutionStateValidator.validate_json(values[0].session)
+    ges = GraphExecutionStateValidator.validate_json(values[0][1])

     # graph values should be populated
     assert ges.graph.get_node("1").prompt == "Banana sushi"

@@ -136,26 +138,26 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
     assert ges.graph.get_node("4").prompt == "Nissan"

     # session ids should match deserialized graph
-    assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values]
+    assert [v[2] for v in values] == [GraphExecutionStateValidator.validate_json(v[1]).id for v in values]

     # should have unique session ids
-    sids = [v.session_id for v in values]
+    sids = [v[2] for v in values]
     assert len(sids) == len(set(sids))

     NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
     # should have 3 node field values
-    assert isinstance(values[0].field_values, str)
-    assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
+    assert isinstance(values[0][4], str)
+    assert len(NodeFieldValueValidator.validate_json(values[0][4])) == 3

     # should have batch id and priority
-    assert all(v.batch_id == b.batch_id for v in values)
-    assert all(v.priority == 0 for v in values)
+    assert all(v[3] == b.batch_id for v in values)
+    assert all(v[5] == 0 for v in values)


 def test_prepare_values_to_insert_with_priority(batch_data_collection, batch_graph):
     b = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
     values = prepare_values_to_insert(queue_id="default", batch=b, priority=1, max_new_queue_items=1000)
-    assert all(v.priority == 1 for v in values)
+    assert all(v[5] == 1 for v in values)


 def test_prepare_values_to_insert_with_max(batch_data_collection, batch_graph):