refactor(app): lean on pydantic to get field types in edge validation logic

Previously we used python's own type introspection utilties to determine
input and output field types. We can use pydantic to get the field types
in a clearer, more direct way.

This improvement also exposed an awkward behaviour in this utility,
where it would return None when a field doesn't exist. I've added a
comment in the code describing the issue, but changing it would require
some significant changes and I don't want to risk breaking anything.
This commit is contained in:
psychedelicious
2025-06-30 17:05:29 +10:00
parent ed7772d993
commit c02be4bdf4

View File

@@ -2,7 +2,7 @@
import copy
import itertools
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
import networkx as nx
from pydantic import (
@@ -58,17 +58,32 @@ class Edge(BaseModel):
def get_output_field_type(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_annotation())
node_output_field = node_outputs.get(field) or None
return node_output_field
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
# would require some fairly significant changes and I don't want risk breaking anything.
try:
invocation_class = type(node)
invocation_output_class = invocation_class.get_output_annotation()
field_info = invocation_output_class.model_fields.get(field)
assert field_info is not None, f"Output field '{field}' not found in {invocation_output_class.get_type()}"
output_field_type = field_info.annotation
return output_field_type
except Exception:
return None
def get_input_field_type(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_inputs = get_type_hints(node_type)
node_input_field = node_inputs.get(field) or None
return node_input_field
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
# would require some fairly significant changes and I don't want risk breaking anything.
try:
invocation_class = type(node)
field_info = invocation_class.model_fields.get(field)
assert field_info is not None, f"Input field '{field}' not found in {invocation_class.get_type()}"
input_field_type = field_info.annotation
return input_field_type
except Exception:
return None
def is_union_subtype(t1, t2):