feat: use ModelValidator naming convention for pydantic type adapters

This is the naming convention in the docs and is also clear.
This commit is contained in:
psychedelicious
2023-10-17 19:46:37 +11:00
parent 3c4f43314c
commit 4012388f0a
9 changed files with 36 additions and 36 deletions

View File

@@ -615,8 +615,8 @@ def test_graph_can_deserialize():
g.add_edge(e)
json = g.model_dump_json()
adapter_graph = TypeAdapter(Graph)
g2 = adapter_graph.validate_json(json)
GraphValidator = TypeAdapter(Graph)
g2 = GraphValidator.validate_json(json)
assert g2 is not None
assert g2.nodes["1"] is not None

View File

@@ -150,9 +150,9 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
assert len(values) == 8
session_adapter = TypeAdapter(GraphExecutionState)
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
# graph should be serialized
ges = session_adapter.validate_json(values[0].session)
ges = GraphExecutionStateValidator.validate_json(values[0].session)
# graph values should be populated
assert ges.graph.get_node("1").prompt == "Banana sushi"
@@ -161,16 +161,16 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
assert ges.graph.get_node("4").prompt == "Nissan"
# session ids should match deserialized graph
assert [v.session_id for v in values] == [session_adapter.validate_json(v.session).id for v in values]
assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values]
# should unique session ids
sids = [v.session_id for v in values]
assert len(sids) == len(set(sids))
nfv_list_adapter = TypeAdapter(list[NodeFieldValue])
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
# should have 3 node field values
assert type(values[0].field_values) is str
assert len(nfv_list_adapter.validate_json(values[0].field_values)) == 3
assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
# should have batch id and priority
assert all(v.batch_id == b.batch_id for v in values)