From c02be4bdf494335c7f661fafac1ac146187709d4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 30 Jun 2025 17:05:29 +1000 Subject: [PATCH] refactor(app): lean on pydantic to get field types in edge validation logic Previously we used python's own type introspection utilties to determine input and output field types. We can use pydantic to get the field types in a clearer, more direct way. This improvement also exposed an awkward behaviour in this utility, where it would return None when a field doesn't exist. I've added a comment in the code describing the issue, but changing it would require some significant changes and I don't want to risk breaking anything. --- invokeai/app/services/shared/graph.py | 33 +++++++++++++++++++-------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 2706a0936a..5e83612017 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -2,7 +2,7 @@ import copy import itertools -from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import Any, Optional, TypeVar, Union, get_args, get_origin import networkx as nx from pydantic import ( @@ -58,17 +58,32 @@ class Edge(BaseModel): def get_output_field_type(node: BaseInvocation, field: str) -> Any: - node_type = type(node) - node_outputs = get_type_hints(node_type.get_output_annotation()) - node_output_field = node_outputs.get(field) or None - return node_output_field + # TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which + # really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this + # would require some fairly significant changes and I don't want risk breaking anything. + try: + invocation_class = type(node) + invocation_output_class = invocation_class.get_output_annotation() + field_info = invocation_output_class.model_fields.get(field) + assert field_info is not None, f"Output field '{field}' not found in {invocation_output_class.get_type()}" + output_field_type = field_info.annotation + return output_field_type + except Exception: + return None def get_input_field_type(node: BaseInvocation, field: str) -> Any: - node_type = type(node) - node_inputs = get_type_hints(node_type) - node_input_field = node_inputs.get(field) or None - return node_input_field + # TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which + # really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this + # would require some fairly significant changes and I don't want risk breaking anything. + try: + invocation_class = type(node) + field_info = invocation_class.model_fields.get(field) + assert field_info is not None, f"Input field '{field}' not found in {invocation_class.get_type()}" + input_field_type = field_info.annotation + return input_field_type + except Exception: + return None def is_union_subtype(t1, t2):