Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-20 09:28:02 -05:00)

Compare commits: controlnet...psyche/per (4 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | ae8459f221 |  |
|  | b73f15d4a2 |  |
|  | 5e1974d924 |  |
|  | 3011dfca16 |  |
```diff
@@ -1,7 +1,7 @@
 import datetime
 import json
 from itertools import chain, product
-from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
+from typing import Generator, Literal, Optional, TypeAlias, Union, cast
 
 from pydantic import (
     AliasChoices,
```
```diff
@@ -406,61 +406,143 @@ class IsFullResult(BaseModel):
 # region Util
 
 
-def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
-    """
-    Populates the given graph with the given batch data items.
-    """
-    graph_clone = graph.model_copy(deep=True)
-    for item in node_field_values:
-        node = graph_clone.get_node(item.node_path)
-        if node is None:
-            continue
-        setattr(node, item.field_name, item.value)
-        graph_clone.update_node(item.node_path, node)
-    return graph_clone
-
-
-def create_session_nfv_tuples(
-    batch: Batch, maximum: int
-) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
+def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, str], None, None]:
     """
-    Create all graph permutations from the given batch data and graph. Yields tuples
-    of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
-    that was applied to the graph.
+    Given a batch and a maximum number of sessions to create, generate a tuple of session_id, session_json, and
+    field_values_json for each session.
+
+    The batch has a "source" graph and a data property. The data property is a list of lists of BatchDatum objects.
+    Each BatchDatum has a field identifier (e.g. a node id and field name), and a list of values to substitute into
+    the field.
+
+    This structure allows us to create a new graph for every possible permutation of BatchDatum objects:
+    - Each BatchDatum can be "expanded" into a dict of node-field-value tuples - one for each item in the BatchDatum.
+    - Zip each inner list of expanded BatchDatum objects together. Call this a "batch_data_list".
+    - Take the cartesian product of all zipped batch_data_lists, resulting in a list of lists of BatchDatum objects.
+      Each inner list now represents the substitution values for a single permutation (session).
+    - For each permutation, substitute the values into the graph
+
+    This function is optimized for performance, as it is used to generate a large number of sessions at once.
+
+    Args:
+        batch: The batch to generate sessions from
+        maximum: The maximum number of sessions to generate
+
+    Returns:
+        A generator that yields tuples of session_id, session_json, and field_values_json for each session. The
+        generator will stop early if the maximum number of sessions is reached.
     """
     # TODO: Should this be a class method on Batch?
-    data: list[list[tuple[NodeFieldValue]]] = []
+    data: list[list[tuple[dict]]] = []
     batch_data_collection = batch.data if batch.data is not None else []
-    for batch_datum_list in batch_data_collection:
-        # each batch_datum_list needs to be convered to NodeFieldValues and then zipped
-        node_field_values_to_zip: list[list[NodeFieldValue]] = []
+
+    for batch_datum_list in batch_data_collection:
+        node_field_values_to_zip: list[list[dict]] = []
+        # Expand each BatchDatum into a list of dicts - one for each item in the BatchDatum
         for batch_datum in batch_datum_list:
             node_field_values = [
-                NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
+                # Note: A tuple here is slightly faster than a dict, but we need the object in dict form to be inserted
+                # in the session_queue table anyways. So, overall creating NFVs as dicts is faster.
+                {"node_path": batch_datum.node_path, "field_name": batch_datum.field_name, "value": item}
                 for item in batch_datum.items
             ]
             node_field_values_to_zip.append(node_field_values)
+        # Zip the dicts together to create a list of dicts for each permutation
         data.append(list(zip(*node_field_values_to_zip, strict=True)))  # type: ignore [arg-type]
 
-    # create generator to yield session,nfv tuples
+    # We serialize the graph and session once, then mutate the graph dict in place for each session.
+    #
+    # This sounds scary, but it's actually fine.
+    #
+    # The batch prep logic injects field values into the same fields for each generated session.
+    #
+    # For example, after the product operation, we'll end up with a list of node-field-value tuples like this:
+    # [
+    #     (
+    #         {"node_path": "1", "field_name": "a", "value": 1},
+    #         {"node_path": "2", "field_name": "b", "value": 2},
+    #         {"node_path": "3", "field_name": "c", "value": 3},
+    #     ),
+    #     (
+    #         {"node_path": "1", "field_name": "a", "value": 4},
+    #         {"node_path": "2", "field_name": "b", "value": 5},
+    #         {"node_path": "3", "field_name": "c", "value": 6},
+    #     )
+    # ]
+    #
+    # Note that each tuple has the same length, and each tuple substitutes values in for exactly the same node fields.
+    # No matter the complexity of the batch, this property holds true.
+    #
+    # This means each permutation's substitution can be done in-place on the same graph dict, because it overwrites the
+    # previous mutation. We only need to serialize the graph once, and then we can mutate it in place for each session.
+    #
+    # Previously, we had created new Graph objects for each session, but this was very slow for large (1k+ session
+    # batches). We then tried dumping the graph to dict and using deep-copy to create a new dict for each session,
+    # but this was also slow.
+    #
+    # Overall, we achieved a 100x speedup by mutating the graph dict in place for each session over creating new Graph
+    # objects for each session.
+    #
+    # We will also mutate the session dict in place, setting a new ID for each session and setting the mutated graph
+    # dict as the session's graph.
+
+    # Dump the batch's graph to a dict once
+    graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
+
+    # We must provide a Graph object when creating the "dummy" session dict, but we don't actually use it. It will be
+    # overwritten for each session by the mutated graph_as_dict.
+    session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
+
+    # Now we can create a generator that yields the session_id, session_json, and field_values_json for each session.
     count = 0
+
+    # Each batch may have multiple runs, so we need to generate the same number of sessions for each run. The total is
+    # still limited by the maximum number of sessions.
     for _ in range(batch.runs):
         for d in product(*data):
             if count >= maximum:
+                # We've reached the maximum number of sessions we may generate
                 return
+
+            # Flatten the list of lists of dicts into a single list of dicts
+            # TODO(psyche): Is there a more efficient way to do this?
             flat_node_field_values = list(chain.from_iterable(d))
-            graph = populate_graph(batch.graph, flat_node_field_values)
-            yield (GraphExecutionState(graph=graph), flat_node_field_values, batch.workflow)
+
+            # Need a fresh ID for each session
+            session_id = uuid_string()
+
+            # Mutate the session dict in place
+            session_dict["id"] = session_id
+
+            # Substitute the values into the graph
+            for nfv in flat_node_field_values:
+                graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]
+
+            # Mutate the session dict in place
+            session_dict["graph"] = graph_as_dict
+
+            # Serialize the session and field values
+            # Note the use of pydantic's to_jsonable_python to handle serialization of any python object, including sets.
+            session_json = json.dumps(session_dict, default=to_jsonable_python)
+            field_values_json = json.dumps(flat_node_field_values, default=to_jsonable_python)
+
+            # Yield the session_id, session_json, and field_values_json
+            yield (session_id, session_json, field_values_json)
+
+            # Increment the count so we know when to stop
             count += 1
 
 
 def calc_session_count(batch: Batch) -> int:
     """
-    Calculates the number of sessions that would be created by the batch, without incurring
-    the overhead of actually generating them. Adapted from `create_sessions().
+    Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
+    creating them, as is done in `create_session_nfv_tuples()`.
+
+    The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
+    many were _actually_ created (which may be less due to the maximum number of sessions).
     """
     # TODO: Should this be a class method on Batch?
     if not batch.data:
```
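The expand → zip → product pipeline and the in-place mutation trick described in the comments above are easier to see in isolation. Below is a minimal, self-contained sketch of the same technique; the `batch_data` and `graph_as_dict` structures are illustrative stand-ins for InvokeAI's `Batch` and `Graph` models, not the project's actual API:

```python
import json
from itertools import chain, product

# Two groups of batch data. Within a group, the expanded value lists are
# zipped together; across groups, the cartesian product is taken.
batch_data = [
    [  # group 1: two fields varied in lockstep
        [{"node_path": "1", "field_name": "prompt", "value": v} for v in ("a", "b")],
        [{"node_path": "2", "field_name": "prompt", "value": v} for v in ("c", "d")],
    ],
    [  # group 2: one field varied independently
        [{"node_path": "3", "field_name": "prompt", "value": v} for v in ("x", "y")],
    ],
]

# Expand and zip each group: group 1 becomes [(1=a, 2=c), (1=b, 2=d)]
data = [list(zip(*group, strict=True)) for group in batch_data]

# Serialize the graph to a plain dict once...
graph_as_dict = {"nodes": {"1": {"prompt": None}, "2": {"prompt": None}, "3": {"prompt": None}}}

# ...then take the product across groups (2 * 2 = 4 permutations) and mutate
# the same dict in place. Every permutation writes to exactly the same node
# fields, so each iteration fully overwrites the previous one - no copy needed.
for d in product(*data):
    for nfv in chain.from_iterable(d):
        graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]
    print(json.dumps(graph_as_dict))  # serialize this permutation's graph
```

The in-place overwrite is only safe because every permutation assigns to the same set of node fields, which is exactly the invariant the comment block above spells out.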
```diff
@@ -476,42 +558,75 @@ def calc_session_count(batch: Batch) -> int:
     return len(data_product) * batch.runs
 
 
-class SessionQueueValueToInsert(NamedTuple):
-    """A tuple of values to insert into the session_queue table"""
-
-    # Careful with the ordering of this - it must match the insert statement
-    queue_id: str  # queue_id
-    session: str  # session json
-    session_id: str  # session_id
-    batch_id: str  # batch_id
-    field_values: Optional[str]  # field_values json
-    priority: int  # priority
-    workflow: Optional[str]  # workflow json
-    origin: str | None
-    destination: str | None
-    retried_from_item_id: int | None = None
-
-
-ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
-
-
-def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
-    values_to_insert: ValuesToInsert = []
-    for session, field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
-        # sessions must have unique id
-        session.id = uuid_string()
+ValueToInsertTuple: TypeAlias = tuple[
+    str,  # queue_id
+    str,  # session (as stringified JSON)
+    str,  # session_id
+    str,  # batch_id
+    str | None,  # field_values (optional, as stringified JSON)
+    int,  # priority
+    str | None,  # workflow (optional, as stringified JSON)
+    str | None,  # origin (optional)
+    str | None,  # destination (optional)
+    str | None,  # retried_from_item_id (optional, this is always None for new items)
+]
+"""A type alias for the tuple of values to insert into the session queue table."""
+
+
+def prepare_values_to_insert(
+    queue_id: str, batch: Batch, priority: int, max_new_queue_items: int
+) -> list[ValueToInsertTuple]:
+    """
+    Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
+    `executemany` statement to insert multiple rows at once.
+
+    Args:
+        queue_id: The ID of the queue to insert the items into
+        batch: The batch to prepare the values for
+        priority: The priority of the queue items
+        max_new_queue_items: The maximum number of queue items to insert
+
+    Returns:
+        A list of tuples to insert into the session queue table. Each tuple contains the following values:
+        - queue_id
+        - session (as stringified JSON)
+        - session_id
+        - batch_id
+        - field_values (optional, as stringified JSON)
+        - priority
+        - workflow (optional, as stringified JSON)
+        - origin (optional)
+        - destination (optional)
+        - retried_from_item_id (optional, this is always None for new items)
+    """
+
+    # A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple, but
+    # measured a ~5% performance improvement by using a normal tuple instead. For very large batches (10k+ items),
+    # this difference becomes noticeable.
+    #
+    # So, despite the inferior DX with normal tuples, we use one here for performance reasons.
+
+    values_to_insert: list[ValueToInsertTuple] = []
+
+    # pydantic's to_jsonable_python handles serialization of any python object, including sets, which json.dumps does
+    # not support by default. Apparently there are sets somewhere in the graph.
+
+    # The same workflow is used for all sessions in the batch - serialize it once
+    workflow_json = json.dumps(batch.workflow, default=to_jsonable_python) if batch.workflow else None
+
+    for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, max_new_queue_items):
         values_to_insert.append(
-            SessionQueueValueToInsert(
-                queue_id=queue_id,
-                session=session.model_dump_json(warnings=False, exclude_none=True),  # as json
-                session_id=session.id,
-                batch_id=batch.batch_id,
-                # must use pydantic_encoder bc field_values is a list of models
-                field_values=json.dumps(field_values, default=to_jsonable_python) if field_values else None,  # as json
-                priority=priority,
-                workflow=json.dumps(workflow, default=to_jsonable_python) if workflow else None,  # as json
-                origin=batch.origin,
-                destination=batch.destination,
+            (
+                queue_id,
+                session_json,
+                session_id,
+                batch.batch_id,
+                field_values_json,
+                priority,
+                workflow_json,
+                batch.origin,
+                batch.destination,
+                None,
             )
         )
     return values_to_insert
```
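Because `prepare_values_to_insert` now returns plain tuples in a fixed positional order, its output can be handed straight to sqlite3's `executemany`. The sketch below shows that pairing against a deliberately simplified, hypothetical table; the real session_queue schema has more columns and constraints:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Hypothetical, trimmed-down schema - column order mirrors ValueToInsertTuple.
conn.execute(
    """
    CREATE TABLE session_queue (
        queue_id TEXT,
        session TEXT,
        session_id TEXT,
        batch_id TEXT,
        field_values TEXT,
        priority INTEGER,
        workflow TEXT,
        origin TEXT,
        destination TEXT,
        retried_from_item_id TEXT
    )
    """
)

# Each row has the shape of a ValueToInsertTuple. With executemany, all rows
# go through a single prepared statement rather than one INSERT per session.
rows = [
    ("default", '{"id": "s1"}', "s1", "b1", None, 0, None, None, None, None),
    ("default", '{"id": "s2"}', "s2", "b1", None, 0, None, None, None, None),
]
conn.executemany("INSERT INTO session_queue VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", rows)
print(conn.execute("SELECT session_id, priority FROM session_queue").fetchall())
```

This is the trade-off the comment calls "inferior DX": nothing but positional convention ties a tuple element to its column, so the alias's inline comments are the only documentation of the ordering.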
```diff
@@ -27,7 +27,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
     SessionQueueItemDTO,
     SessionQueueItemNotFoundError,
     SessionQueueStatus,
-    SessionQueueValueToInsert,
     calc_session_count,
     prepare_values_to_insert,
 )
```
```diff
@@ -772,7 +771,7 @@ class SqliteSessionQueue(SessionQueueBase):
         try:
             self.__lock.acquire()
 
-            values_to_insert: list[SessionQueueValueToInsert] = []
+            values_to_insert: list[tuple] = []
             retried_item_ids: list[int] = []
 
             for item_id in item_ids:
```
```diff
@@ -798,17 +797,17 @@ class SqliteSessionQueue(SessionQueueBase):
                     else queue_item.item_id
                 )
 
-                value_to_insert = SessionQueueValueToInsert(
-                    queue_id=queue_item.queue_id,
-                    batch_id=queue_item.batch_id,
-                    destination=queue_item.destination,
-                    field_values=field_values_json,
-                    origin=queue_item.origin,
-                    priority=queue_item.priority,
-                    workflow=workflow_json,
-                    session=cloned_session_json,
-                    session_id=cloned_session.id,
-                    retried_from_item_id=retried_from_item_id,
+                value_to_insert = (
+                    queue_item.queue_id,
+                    queue_item.batch_id,
+                    queue_item.destination,
+                    field_values_json,
+                    queue_item.origin,
+                    queue_item.priority,
+                    workflow_json,
+                    cloned_session_json,
+                    cloned_session.id,
+                    retried_from_item_id,
                 )
                 values_to_insert.append(value_to_insert)
 
```
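Note that the retry path builds its tuple in a different positional order (queue_id, batch_id, destination, ...) than `ValueToInsertTuple`; with plain tuples, the only thing that keeps values in the right columns is the INSERT statement the tuple is paired with. A hedged sketch of what that pairing must look like; the column list is inferred from the tuple order above, not copied from the source:

```python
# Hypothetical INSERT for the retry path. The placeholder order must mirror
# the tuple construction exactly, or values silently land in the wrong columns.
RETRY_INSERT_SQL = """
    INSERT INTO session_queue (
        queue_id, batch_id, destination, field_values, origin,
        priority, workflow, session, session_id, retried_from_item_id
    )
    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
# cursor.executemany(RETRY_INSERT_SQL, values_to_insert)
```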
```diff
@@ -1,3 +1,5 @@
+import json
+
 import pytest
 from pydantic import TypeAdapter, ValidationError
 
```
```diff
@@ -44,46 +46,46 @@ def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph
     # 2 list[BatchDatum] * length 2 * 2 runs = 8
     assert len(t) == 8
 
-    assert t[0][0].graph.get_node("1").prompt == "Banana sushi"
-    assert t[0][0].graph.get_node("2").prompt == "Strawberry sushi"
-    assert t[0][0].graph.get_node("3").prompt == "Orange sushi"
-    assert t[0][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[0][1])["graph"]["nodes"]["1"]["prompt"] == "Banana sushi"
+    assert json.loads(t[0][1])["graph"]["nodes"]["2"]["prompt"] == "Strawberry sushi"
+    assert json.loads(t[0][1])["graph"]["nodes"]["3"]["prompt"] == "Orange sushi"
+    assert json.loads(t[0][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
 
-    assert t[1][0].graph.get_node("1").prompt == "Banana sushi"
-    assert t[1][0].graph.get_node("2").prompt == "Strawberry sushi"
-    assert t[1][0].graph.get_node("3").prompt == "Apple sushi"
-    assert t[1][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[1][1])["graph"]["nodes"]["1"]["prompt"] == "Banana sushi"
+    assert json.loads(t[1][1])["graph"]["nodes"]["2"]["prompt"] == "Strawberry sushi"
+    assert json.loads(t[1][1])["graph"]["nodes"]["3"]["prompt"] == "Apple sushi"
+    assert json.loads(t[1][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
 
-    assert t[2][0].graph.get_node("1").prompt == "Grape sushi"
-    assert t[2][0].graph.get_node("2").prompt == "Blueberry sushi"
-    assert t[2][0].graph.get_node("3").prompt == "Orange sushi"
-    assert t[2][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[2][1])["graph"]["nodes"]["1"]["prompt"] == "Grape sushi"
+    assert json.loads(t[2][1])["graph"]["nodes"]["2"]["prompt"] == "Blueberry sushi"
+    assert json.loads(t[2][1])["graph"]["nodes"]["3"]["prompt"] == "Orange sushi"
+    assert json.loads(t[2][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
 
-    assert t[3][0].graph.get_node("1").prompt == "Grape sushi"
-    assert t[3][0].graph.get_node("2").prompt == "Blueberry sushi"
-    assert t[3][0].graph.get_node("3").prompt == "Apple sushi"
-    assert t[3][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[3][1])["graph"]["nodes"]["1"]["prompt"] == "Grape sushi"
+    assert json.loads(t[3][1])["graph"]["nodes"]["2"]["prompt"] == "Blueberry sushi"
+    assert json.loads(t[3][1])["graph"]["nodes"]["3"]["prompt"] == "Apple sushi"
+    assert json.loads(t[3][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
 
     # repeat for second run
-    assert t[4][0].graph.get_node("1").prompt == "Banana sushi"
-    assert t[4][0].graph.get_node("2").prompt == "Strawberry sushi"
-    assert t[4][0].graph.get_node("3").prompt == "Orange sushi"
-    assert t[4][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[4][1])["graph"]["nodes"]["1"]["prompt"] == "Banana sushi"
+    assert json.loads(t[4][1])["graph"]["nodes"]["2"]["prompt"] == "Strawberry sushi"
+    assert json.loads(t[4][1])["graph"]["nodes"]["3"]["prompt"] == "Orange sushi"
+    assert json.loads(t[4][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
 
-    assert t[5][0].graph.get_node("1").prompt == "Banana sushi"
-    assert t[5][0].graph.get_node("2").prompt == "Strawberry sushi"
-    assert t[5][0].graph.get_node("3").prompt == "Apple sushi"
-    assert t[5][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[5][1])["graph"]["nodes"]["1"]["prompt"] == "Banana sushi"
+    assert json.loads(t[5][1])["graph"]["nodes"]["2"]["prompt"] == "Strawberry sushi"
+    assert json.loads(t[5][1])["graph"]["nodes"]["3"]["prompt"] == "Apple sushi"
+    assert json.loads(t[5][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
 
-    assert t[6][0].graph.get_node("1").prompt == "Grape sushi"
-    assert t[6][0].graph.get_node("2").prompt == "Blueberry sushi"
-    assert t[6][0].graph.get_node("3").prompt == "Orange sushi"
-    assert t[6][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[6][1])["graph"]["nodes"]["1"]["prompt"] == "Grape sushi"
+    assert json.loads(t[6][1])["graph"]["nodes"]["2"]["prompt"] == "Blueberry sushi"
+    assert json.loads(t[6][1])["graph"]["nodes"]["3"]["prompt"] == "Orange sushi"
+    assert json.loads(t[6][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
 
-    assert t[7][0].graph.get_node("1").prompt == "Grape sushi"
-    assert t[7][0].graph.get_node("2").prompt == "Blueberry sushi"
-    assert t[7][0].graph.get_node("3").prompt == "Apple sushi"
-    assert t[7][0].graph.get_node("4").prompt == "Nissan"
+    assert json.loads(t[7][1])["graph"]["nodes"]["1"]["prompt"] == "Grape sushi"
+    assert json.loads(t[7][1])["graph"]["nodes"]["2"]["prompt"] == "Blueberry sushi"
+    assert json.loads(t[7][1])["graph"]["nodes"]["3"]["prompt"] == "Apple sushi"
+    assert json.loads(t[7][1])["graph"]["nodes"]["4"]["prompt"] == "Nissan"
 
 
 def test_create_sessions_from_batch_without_runs(batch_data_collection, batch_graph):
```
```diff
@@ -127,7 +129,7 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
 
     GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
     # graph should be serialized
-    ges = GraphExecutionStateValidator.validate_json(values[0].session)
+    ges = GraphExecutionStateValidator.validate_json(values[0][1])
 
     # graph values should be populated
     assert ges.graph.get_node("1").prompt == "Banana sushi"
```
```diff
@@ -136,26 +138,26 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
     assert ges.graph.get_node("4").prompt == "Nissan"
 
     # session ids should match deserialized graph
-    assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values]
+    assert [v[2] for v in values] == [GraphExecutionStateValidator.validate_json(v[1]).id for v in values]
 
     # should unique session ids
-    sids = [v.session_id for v in values]
+    sids = [v[2] for v in values]
     assert len(sids) == len(set(sids))
 
     NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
     # should have 3 node field values
-    assert isinstance(values[0].field_values, str)
-    assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
+    assert isinstance(values[0][4], str)
+    assert len(NodeFieldValueValidator.validate_json(values[0][4])) == 3
 
     # should have batch id and priority
-    assert all(v.batch_id == b.batch_id for v in values)
-    assert all(v.priority == 0 for v in values)
+    assert all(v[3] == b.batch_id for v in values)
+    assert all(v[5] == 0 for v in values)
 
 
 def test_prepare_values_to_insert_with_priority(batch_data_collection, batch_graph):
     b = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
     values = prepare_values_to_insert(queue_id="default", batch=b, priority=1, max_new_queue_items=1000)
-    assert all(v.priority == 1 for v in values)
+    assert all(v[5] == 1 for v in values)
 
 
 def test_prepare_values_to_insert_with_max(batch_data_collection, batch_graph):
```
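For readers decoding the bare indices in the updated tests: `t` holds the 3-tuples yielded by `create_session_nfv_tuples`, and `values` holds `ValueToInsertTuple`s. A small reference sketch (these constant names are ours, for illustration only; the tests index directly):

```python
# create_session_nfv_tuples yields (session_id, session_json, field_values_json),
# so t[0][1] is the serialized session and json.loads(t[0][1])["graph"] its graph.
SESSION_ID, SESSION_JSON, FIELD_VALUES_JSON = 0, 1, 2

# ValueToInsertTuple is (queue_id, session, session_id, batch_id, field_values,
# priority, workflow, origin, destination, retried_from_item_id), so in the
# tests v[1] is the session JSON, v[2] the session id, v[3] the batch id,
# v[4] the field_values JSON, and v[5] the priority.
QUEUE_ID, SESSION, SESSION_ID_POS, BATCH_ID, FIELD_VALUES, PRIORITY = range(6)
```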