mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 23:58:03 -05:00
Compare commits
4 Commits
bria-UI
...
psyche/per
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae8459f221 | ||
|
|
b73f15d4a2 | ||
|
|
5e1974d924 | ||
|
|
3011dfca16 |
@@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import json
|
||||
from itertools import chain, product
|
||||
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
|
||||
from typing import Generator, Literal, Optional, TypeAlias, Union, cast
|
||||
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
@@ -406,61 +406,143 @@ class IsFullResult(BaseModel):
|
||||
# region Util
|
||||
|
||||
|
||||
def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
|
||||
def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, str], None, None]:
|
||||
"""
|
||||
Populates the given graph with the given batch data items.
|
||||
"""
|
||||
graph_clone = graph.model_copy(deep=True)
|
||||
for item in node_field_values:
|
||||
node = graph_clone.get_node(item.node_path)
|
||||
if node is None:
|
||||
continue
|
||||
setattr(node, item.field_name, item.value)
|
||||
graph_clone.update_node(item.node_path, node)
|
||||
return graph_clone
|
||||
Given a batch and a maximum number of sessions to create, generate a tuple of session_id, session_json, and
|
||||
field_values_json for each session.
|
||||
|
||||
The batch has a "source" graph and a data property. The data property is a list of lists of BatchDatum objects.
|
||||
Each BatchDatum has a field identifier (e.g. a node id and field name), and a list of values to substitute into
|
||||
the field.
|
||||
|
||||
def create_session_nfv_tuples(
|
||||
batch: Batch, maximum: int
|
||||
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
|
||||
"""
|
||||
Create all graph permutations from the given batch data and graph. Yields tuples
|
||||
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
|
||||
that was applied to the graph.
|
||||
This structure allows us to create a new graph for every possible permutation of BatchDatum objects:
|
||||
- Each BatchDatum can be "expanded" into a dict of node-field-value tuples - one for each item in the BatchDatum.
|
||||
- Zip each inner list of expanded BatchDatum objects together. Call this a "batch_data_list".
|
||||
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of permutations of BatchDatum
|
||||
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of lists of BatchDatum objects.
|
||||
Each inner list now represents the substitution values for a single permutation (session).
|
||||
- For each permutation, substitute the values into the graph
|
||||
|
||||
This function is optimized for performance, as it is used to generate a large number of sessions at once.
|
||||
|
||||
Args:
|
||||
batch: The batch to generate sessions from
|
||||
maximum: The maximum number of sessions to generate
|
||||
|
||||
Returns:
|
||||
A generator that yields tuples of session_id, session_json, and field_values_json for each session. The
|
||||
generator will stop early if the maximum number of sessions is reached.
|
||||
"""
|
||||
|
||||
# TODO: Should this be a class method on Batch?
|
||||
|
||||
data: list[list[tuple[NodeFieldValue]]] = []
|
||||
data: list[list[tuple[dict]]] = []
|
||||
batch_data_collection = batch.data if batch.data is not None else []
|
||||
for batch_datum_list in batch_data_collection:
|
||||
# each batch_datum_list needs to be convered to NodeFieldValues and then zipped
|
||||
|
||||
node_field_values_to_zip: list[list[NodeFieldValue]] = []
|
||||
for batch_datum_list in batch_data_collection:
|
||||
node_field_values_to_zip: list[list[dict]] = []
|
||||
# Expand each BatchDatum into a list of dicts - one for each item in the BatchDatum
|
||||
for batch_datum in batch_datum_list:
|
||||
node_field_values = [
|
||||
NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
|
||||
# Note: A tuple here is slightly faster than a dict, but we need the object in dict form to be inserted
|
||||
# in the session_queue table anyways. So, overall creating NFVs as dicts is faster.
|
||||
{"node_path": batch_datum.node_path, "field_name": batch_datum.field_name, "value": item}
|
||||
for item in batch_datum.items
|
||||
]
|
||||
node_field_values_to_zip.append(node_field_values)
|
||||
# Zip the dicts together to create a list of dicts for each permutation
|
||||
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
|
||||
|
||||
# create generator to yield session,nfv tuples
|
||||
# We serialize the graph and session once, then mutate the graph dict in place for each session.
|
||||
#
|
||||
# This sounds scary, but it's actually fine.
|
||||
#
|
||||
# The batch prep logic injects field values into the same fields for each generated session.
|
||||
#
|
||||
# For example, after the product operation, we'll end up with a list of node-field-value tuples like this:
|
||||
# [
|
||||
# (
|
||||
# {"node_path": "1", "field_name": "a", "value": 1},
|
||||
# {"node_path": "2", "field_name": "b", "value": 2},
|
||||
# {"node_path": "3", "field_name": "c", "value": 3},
|
||||
# ),
|
||||
# (
|
||||
# {"node_path": "1", "field_name": "a", "value": 4},
|
||||
# {"node_path": "2", "field_name": "b", "value": 5},
|
||||
# {"node_path": "3", "field_name": "c", "value": 6},
|
||||
# )
|
||||
# ]
|
||||
#
|
||||
# Note that each tuple has the same length, and each tuple substitutes values in for exactly the same node fields.
|
||||
# No matter the complexity of the batch, this property holds true.
|
||||
#
|
||||
# This means each permutation's substitution can be done in-place on the same graph dict, because it overwrites the
|
||||
# previous mutation. We only need to serialize the graph once, and then we can mutate it in place for each session.
|
||||
#
|
||||
# Previously, we had created new Graph objects for each session, but this was very slow for large (1k+ session
|
||||
# batches). We then tried dumping the graph to dict and using deep-copy to create a new dict for each session,
|
||||
# but this was also slow.
|
||||
#
|
||||
# Overall, we achieved a 100x speedup by mutating the graph dict in place for each session over creating new Graph
|
||||
# objects for each session.
|
||||
#
|
||||
# We will also mutate the session dict in place, setting a new ID for each session and setting the mutated graph
|
||||
# dict as the session's graph.
|
||||
|
||||
# Dump the batch's graph to a dict once
|
||||
graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
|
||||
|
||||
# We must provide a Graph object when creating the "dummy" session dict, but we don't actually use it. It will be
|
||||
# overwritten for each session by the mutated graph_as_dict.
|
||||
session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
|
||||
|
||||
# Now we can create a generator that yields the session_id, session_json, and field_values_json for each session.
|
||||
count = 0
|
||||
|
||||
# Each batch may have multiple runs, so we need to generate the same number of sessions for each run. The total is
|
||||
# still limited by the maximum number of sessions.
|
||||
for _ in range(batch.runs):
|
||||
for d in product(*data):
|
||||
if count >= maximum:
|
||||
# We've reached the maximum number of sessions we may generate
|
||||
return
|
||||
|
||||
# Flatten the list of lists of dicts into a single list of dicts
|
||||
# TODO(psyche): Is the a more efficient way to do this?
|
||||
flat_node_field_values = list(chain.from_iterable(d))
|
||||
graph = populate_graph(batch.graph, flat_node_field_values)
|
||||
yield (GraphExecutionState(graph=graph), flat_node_field_values, batch.workflow)
|
||||
|
||||
# Need a fresh ID for each session
|
||||
session_id = uuid_string()
|
||||
|
||||
# Mutate the session dict in place
|
||||
session_dict["id"] = session_id
|
||||
|
||||
# Substitute the values into the graph
|
||||
for nfv in flat_node_field_values:
|
||||
graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]
|
||||
|
||||
# Mutate the session dict in place
|
||||
session_dict["graph"] = graph_as_dict
|
||||
|
||||
# Serialize the session and field values
|
||||
# Note the use of pydantic's to_jsonable_python to handle serialization of any python object, including sets.
|
||||
session_json = json.dumps(session_dict, default=to_jsonable_python)
|
||||
field_values_json = json.dumps(flat_node_field_values, default=to_jsonable_python)
|
||||
|
||||
# Yield the session_id, session_json, and field_values_json
|
||||
yield (session_id, session_json, field_values_json)
|
||||
|
||||
# Increment the count so we know when to stop
|
||||
count += 1
|
||||
|
||||
|
||||
def calc_session_count(batch: Batch) -> int:
|
||||
"""
|
||||
Calculates the number of sessions that would be created by the batch, without incurring
|
||||
the overhead of actually generating them. Adapted from `create_sessions().
|
||||
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
|
||||
creating them, as is done in `create_session_nfv_tuples()`.
|
||||
|
||||
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
|
||||
many were _actually_ created (which may be less due to the maximum number of sessions).
|
||||
"""
|
||||
# TODO: Should this be a class method on Batch?
|
||||
if not batch.data:
|
||||
@@ -476,42 +558,75 @@ def calc_session_count(batch: Batch) -> int:
|
||||
return len(data_product) * batch.runs
|
||||
|
||||
|
||||
class SessionQueueValueToInsert(NamedTuple):
|
||||
"""A tuple of values to insert into the session_queue table"""
|
||||
|
||||
# Careful with the ordering of this - it must match the insert statement
|
||||
queue_id: str # queue_id
|
||||
session: str # session json
|
||||
session_id: str # session_id
|
||||
batch_id: str # batch_id
|
||||
field_values: Optional[str] # field_values json
|
||||
priority: int # priority
|
||||
workflow: Optional[str] # workflow json
|
||||
origin: str | None
|
||||
destination: str | None
|
||||
retried_from_item_id: int | None = None
|
||||
ValueToInsertTuple: TypeAlias = tuple[
|
||||
str, # queue_id
|
||||
str, # session (as stringified JSON)
|
||||
str, # session_id
|
||||
str, # batch_id
|
||||
str | None, # field_values (optional, as stringified JSON)
|
||||
int, # priority
|
||||
str | None, # workflow (optional, as stringified JSON)
|
||||
str | None, # origin (optional)
|
||||
str | None, # destination (optional)
|
||||
str | None, # retried_from_item_id (optional, this is always None for new items)
|
||||
]
|
||||
"""A type alias for the tuple of values to insert into the session queue table."""
|
||||
|
||||
|
||||
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||
def prepare_values_to_insert(
|
||||
queue_id: str, batch: Batch, priority: int, max_new_queue_items: int
|
||||
) -> list[ValueToInsertTuple]:
|
||||
"""
|
||||
Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
|
||||
`executemany` statement to insert multiple rows at once.
|
||||
|
||||
Args:
|
||||
queue_id: The ID of the queue to insert the items into
|
||||
batch: The batch to prepare the values for
|
||||
priority: The priority of the queue items
|
||||
max_new_queue_items: The maximum number of queue items to insert
|
||||
|
||||
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
|
||||
values_to_insert: ValuesToInsert = []
|
||||
for session, field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
|
||||
# sessions must have unique id
|
||||
session.id = uuid_string()
|
||||
Returns:
|
||||
A list of tuples to insert into the session queue table. Each tuple contains the following values:
|
||||
- queue_id
|
||||
- session (as stringified JSON)
|
||||
- session_id
|
||||
- batch_id
|
||||
- field_values (optional, as stringified JSON)
|
||||
- priority
|
||||
- workflow (optional, as stringified JSON)
|
||||
- origin (optional)
|
||||
- destination (optional)
|
||||
- retried_from_item_id (optional, this is always None for new items)
|
||||
"""
|
||||
|
||||
# A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple, but
|
||||
# measured a ~5% performance improvement by using a normal tuple instead. For very large batches (10k+ items), the
|
||||
# this difference becomes noticeable.
|
||||
#
|
||||
# So, despite the inferior DX with normal tuples, we use one here for performance reasons.
|
||||
|
||||
values_to_insert: list[ValueToInsertTuple] = []
|
||||
|
||||
# pydantic's to_jsonable_python handles serialization of any python object, including sets, which json.dumps does
|
||||
# not support by default. Apparently there are sets somewhere in the graph.
|
||||
|
||||
# The same workflow is used for all sessions in the batch - serialize it once
|
||||
workflow_json = json.dumps(batch.workflow, default=to_jsonable_python) if batch.workflow else None
|
||||
|
||||
for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, max_new_queue_items):
|
||||
values_to_insert.append(
|
||||
SessionQueueValueToInsert(
|
||||
queue_id=queue_id,
|
||||
session=session.model_dump_json(warnings=False, exclude_none=True), # as json
|
||||
session_id=session.id,
|
||||
batch_id=batch.batch_id,
|
||||
# must use pydantic_encoder bc field_values is a list of models
|
||||
field_values=json.dumps(field_values, default=to_jsonable_python) if field_values else None, # as json
|
||||
priority=priority,
|
||||
workflow=json.dumps(workflow, default=to_jsonable_python) if workflow else None, # as json
|
||||
origin=batch.origin,
|
||||
destination=batch.destination,
|
||||
(
|
||||
queue_id,
|
||||
session_json,
|
||||
session_id,
|
||||
batch.batch_id,
|
||||
field_values_json,
|
||||
priority,
|
||||
workflow_json,
|
||||
batch.origin,
|
||||
batch.destination,
|
||||
None,
|
||||
)
|
||||
)
|
||||
return values_to_insert
|
||||
|
||||
@@ -27,7 +27,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
SessionQueueValueToInsert,
|
||||
calc_session_count,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
@@ -772,7 +771,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
|
||||
values_to_insert: list[SessionQueueValueToInsert] = []
|
||||
values_to_insert: list[tuple] = []
|
||||
retried_item_ids: list[int] = []
|
||||
|
||||
for item_id in item_ids:
|
||||
@@ -798,17 +797,17 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
else queue_item.item_id
|
||||
)
|
||||
|
||||
value_to_insert = SessionQueueValueToInsert(
|
||||
queue_id=queue_item.queue_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
destination=queue_item.destination,
|
||||
field_values=field_values_json,
|
||||
origin=queue_item.origin,
|
||||
priority=queue_item.priority,
|
||||
workflow=workflow_json,
|
||||
session=cloned_session_json,
|
||||
session_id=cloned_session.id,
|
||||
retried_from_item_id=retried_from_item_id,
|
||||
value_to_insert = (
|
||||
queue_item.queue_id,
|
||||
queue_item.batch_id,
|
||||
queue_item.destination,
|
||||
field_values_json,
|
||||
queue_item.origin,
|
||||
queue_item.priority,
|
||||
workflow_json,
|
||||
cloned_session_json,
|
||||
cloned_session.id,
|
||||
retried_from_item_id,
|
||||
)
|
||||
values_to_insert.append(value_to_insert)
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
@@ -44,46 +46,46 @@ def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph
|
||||
# 2 list[BatchDatum] * length 2 * 2 runs = 8
|
||||
assert len(t) == 8
|
||||
|
||||
assert t[0][0].graph.get_node("1").prompt == "Banana sushi"
|
||||
assert t[0][0].graph.get_node("2").prompt == "Strawberry sushi"
|
||||
assert t[0][0].graph.get_node("3").prompt == "Orange sushi"
|
||||
assert t[0][0].graph.get_node("4").prompt == "Nissan"
|
||||
assert json.loads(t[0][1])["graph"]["nodes"]["1"]["prompt"] == "Banana sushi"
|
||||
assert json.loads(t[0][1])["graph"]["nodes"]["2"]["prompt"] == "Strawberry sushi"
|
||||
assert json.loads(t[0][1])["graph"]["nodes"]["3"]["prompt"] == "Orange sushi"
|
||||
assert json.loads(t[0][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
|
||||
|
||||
assert t[1][0].graph.get_node("1").prompt == "Banana sushi"
|
||||
assert t[1][0].graph.get_node("2").prompt == "Strawberry sushi"
|
||||
assert t[1][0].graph.get_node("3").prompt == "Apple sushi"
|
||||
assert t[1][0].graph.get_node("4").prompt == "Nissan"
|
||||
assert json.loads(t[1][1])["graph"]["nodes"]["1"]["prompt"] == "Banana sushi"
|
||||
assert json.loads(t[1][1])["graph"]["nodes"]["2"]["prompt"] == "Strawberry sushi"
|
||||
assert json.loads(t[1][1])["graph"]["nodes"]["3"]["prompt"] == "Apple sushi"
|
||||
assert json.loads(t[1][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
|
||||
|
||||
assert t[2][0].graph.get_node("1").prompt == "Grape sushi"
|
||||
assert t[2][0].graph.get_node("2").prompt == "Blueberry sushi"
|
||||
assert t[2][0].graph.get_node("3").prompt == "Orange sushi"
|
||||
assert t[2][0].graph.get_node("4").prompt == "Nissan"
|
||||
assert json.loads(t[2][1])["graph"]["nodes"]["1"]["prompt"] == "Grape sushi"
|
||||
assert json.loads(t[2][1])["graph"]["nodes"]["2"]["prompt"] == "Blueberry sushi"
|
||||
assert json.loads(t[2][1])["graph"]["nodes"]["3"]["prompt"] == "Orange sushi"
|
||||
assert json.loads(t[2][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
|
||||
|
||||
assert t[3][0].graph.get_node("1").prompt == "Grape sushi"
|
||||
assert t[3][0].graph.get_node("2").prompt == "Blueberry sushi"
|
||||
assert t[3][0].graph.get_node("3").prompt == "Apple sushi"
|
||||
assert t[3][0].graph.get_node("4").prompt == "Nissan"
|
||||
assert json.loads(t[3][1])["graph"]["nodes"]["1"]["prompt"] == "Grape sushi"
|
||||
assert json.loads(t[3][1])["graph"]["nodes"]["2"]["prompt"] == "Blueberry sushi"
|
||||
assert json.loads(t[3][1])["graph"]["nodes"]["3"]["prompt"] == "Apple sushi"
|
||||
assert json.loads(t[3][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
|
||||
|
||||
# repeat for second run
|
||||
assert t[4][0].graph.get_node("1").prompt == "Banana sushi"
|
||||
assert t[4][0].graph.get_node("2").prompt == "Strawberry sushi"
|
||||
assert t[4][0].graph.get_node("3").prompt == "Orange sushi"
|
||||
assert t[4][0].graph.get_node("4").prompt == "Nissan"
|
||||
assert json.loads(t[4][1])["graph"]["nodes"]["1"]["prompt"] == "Banana sushi"
|
||||
assert json.loads(t[4][1])["graph"]["nodes"]["2"]["prompt"] == "Strawberry sushi"
|
||||
assert json.loads(t[4][1])["graph"]["nodes"]["3"]["prompt"] == "Orange sushi"
|
||||
assert json.loads(t[4][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
|
||||
|
||||
assert t[5][0].graph.get_node("1").prompt == "Banana sushi"
|
||||
assert t[5][0].graph.get_node("2").prompt == "Strawberry sushi"
|
||||
assert t[5][0].graph.get_node("3").prompt == "Apple sushi"
|
||||
assert t[5][0].graph.get_node("4").prompt == "Nissan"
|
||||
assert json.loads(t[5][1])["graph"]["nodes"]["1"]["prompt"] == "Banana sushi"
|
||||
assert json.loads(t[5][1])["graph"]["nodes"]["2"]["prompt"] == "Strawberry sushi"
|
||||
assert json.loads(t[5][1])["graph"]["nodes"]["3"]["prompt"] == "Apple sushi"
|
||||
assert json.loads(t[5][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
|
||||
|
||||
assert t[6][0].graph.get_node("1").prompt == "Grape sushi"
|
||||
assert t[6][0].graph.get_node("2").prompt == "Blueberry sushi"
|
||||
assert t[6][0].graph.get_node("3").prompt == "Orange sushi"
|
||||
assert t[6][0].graph.get_node("4").prompt == "Nissan"
|
||||
assert json.loads(t[6][1])["graph"]["nodes"]["1"]["prompt"] == "Grape sushi"
|
||||
assert json.loads(t[6][1])["graph"]["nodes"]["2"]["prompt"] == "Blueberry sushi"
|
||||
assert json.loads(t[6][1])["graph"]["nodes"]["3"]["prompt"] == "Orange sushi"
|
||||
assert json.loads(t[6][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
|
||||
|
||||
assert t[7][0].graph.get_node("1").prompt == "Grape sushi"
|
||||
assert t[7][0].graph.get_node("2").prompt == "Blueberry sushi"
|
||||
assert t[7][0].graph.get_node("3").prompt == "Apple sushi"
|
||||
assert t[7][0].graph.get_node("4").prompt == "Nissan"
|
||||
assert json.loads(t[7][1])["graph"]["nodes"]["1"]["prompt"] == "Grape sushi"
|
||||
assert json.loads(t[7][1])["graph"]["nodes"]["2"]["prompt"] == "Blueberry sushi"
|
||||
assert json.loads(t[7][1])["graph"]["nodes"]["3"]["prompt"] == "Apple sushi"
|
||||
assert json.loads(t[7][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
|
||||
|
||||
|
||||
def test_create_sessions_from_batch_without_runs(batch_data_collection, batch_graph):
|
||||
@@ -127,7 +129,7 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
|
||||
|
||||
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
|
||||
# graph should be serialized
|
||||
ges = GraphExecutionStateValidator.validate_json(values[0].session)
|
||||
ges = GraphExecutionStateValidator.validate_json(values[0][1])
|
||||
|
||||
# graph values should be populated
|
||||
assert ges.graph.get_node("1").prompt == "Banana sushi"
|
||||
@@ -136,26 +138,26 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
|
||||
assert ges.graph.get_node("4").prompt == "Nissan"
|
||||
|
||||
# session ids should match deserialized graph
|
||||
assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values]
|
||||
assert [v[2] for v in values] == [GraphExecutionStateValidator.validate_json(v[1]).id for v in values]
|
||||
|
||||
# should unique session ids
|
||||
sids = [v.session_id for v in values]
|
||||
sids = [v[2] for v in values]
|
||||
assert len(sids) == len(set(sids))
|
||||
|
||||
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
|
||||
# should have 3 node field values
|
||||
assert isinstance(values[0].field_values, str)
|
||||
assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
|
||||
assert isinstance(values[0][4], str)
|
||||
assert len(NodeFieldValueValidator.validate_json(values[0][4])) == 3
|
||||
|
||||
# should have batch id and priority
|
||||
assert all(v.batch_id == b.batch_id for v in values)
|
||||
assert all(v.priority == 0 for v in values)
|
||||
assert all(v[3] == b.batch_id for v in values)
|
||||
assert all(v[5] == 0 for v in values)
|
||||
|
||||
|
||||
def test_prepare_values_to_insert_with_priority(batch_data_collection, batch_graph):
|
||||
b = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
|
||||
values = prepare_values_to_insert(queue_id="default", batch=b, priority=1, max_new_queue_items=1000)
|
||||
assert all(v.priority == 1 for v in values)
|
||||
assert all(v[5] == 1 for v in values)
|
||||
|
||||
|
||||
def test_prepare_values_to_insert_with_max(batch_data_collection, batch_graph):
|
||||
|
||||
Reference in New Issue
Block a user