mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat: use ModelValidator naming convention for pydantic type adapters
This is the naming convention in the docs and is also clear.
This commit is contained in:
@@ -615,8 +615,8 @@ def test_graph_can_deserialize():
|
||||
g.add_edge(e)
|
||||
|
||||
json = g.model_dump_json()
|
||||
adapter_graph = TypeAdapter(Graph)
|
||||
g2 = adapter_graph.validate_json(json)
|
||||
GraphValidator = TypeAdapter(Graph)
|
||||
g2 = GraphValidator.validate_json(json)
|
||||
|
||||
assert g2 is not None
|
||||
assert g2.nodes["1"] is not None
|
||||
|
||||
@@ -150,9 +150,9 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
|
||||
values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
|
||||
assert len(values) == 8
|
||||
|
||||
session_adapter = TypeAdapter(GraphExecutionState)
|
||||
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
|
||||
# graph should be serialized
|
||||
ges = session_adapter.validate_json(values[0].session)
|
||||
ges = GraphExecutionStateValidator.validate_json(values[0].session)
|
||||
|
||||
# graph values should be populated
|
||||
assert ges.graph.get_node("1").prompt == "Banana sushi"
|
||||
@@ -161,16 +161,16 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
|
||||
assert ges.graph.get_node("4").prompt == "Nissan"
|
||||
|
||||
# session ids should match deserialized graph
|
||||
assert [v.session_id for v in values] == [session_adapter.validate_json(v.session).id for v in values]
|
||||
assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values]
|
||||
|
||||
# should unique session ids
|
||||
sids = [v.session_id for v in values]
|
||||
assert len(sids) == len(set(sids))
|
||||
|
||||
nfv_list_adapter = TypeAdapter(list[NodeFieldValue])
|
||||
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
|
||||
# should have 3 node field values
|
||||
assert type(values[0].field_values) is str
|
||||
assert len(nfv_list_adapter.validate_json(values[0].field_values)) == 3
|
||||
assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
|
||||
|
||||
# should have batch id and priority
|
||||
assert all(v.batch_id == b.batch_id for v in values)
|
||||
|
||||
Reference in New Issue
Block a user