diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 2706a0936a..5e83612017 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -2,7 +2,7 @@ import copy import itertools -from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import Any, Optional, TypeVar, Union, get_args, get_origin import networkx as nx from pydantic import ( @@ -58,17 +58,32 @@ class Edge(BaseModel): def get_output_field_type(node: BaseInvocation, field: str) -> Any: - node_type = type(node) - node_outputs = get_type_hints(node_type.get_output_annotation()) - node_output_field = node_outputs.get(field) or None - return node_output_field + # TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which + # really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this + # would require some fairly significant changes and I don't want risk breaking anything. + try: + invocation_class = type(node) + invocation_output_class = invocation_class.get_output_annotation() + field_info = invocation_output_class.model_fields.get(field) + assert field_info is not None, f"Output field '{field}' not found in {invocation_output_class.get_type()}" + output_field_type = field_info.annotation + return output_field_type + except Exception: + return None def get_input_field_type(node: BaseInvocation, field: str) -> Any: - node_type = type(node) - node_inputs = get_type_hints(node_type) - node_input_field = node_inputs.get(field) or None - return node_input_field + # TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which + # really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this + # would require some fairly significant changes and I don't want risk breaking anything. + try: + invocation_class = type(node) + field_info = invocation_class.model_fields.get(field) + assert field_info is not None, f"Input field '{field}' not found in {invocation_class.get_type()}" + input_field_type = field_info.annotation + return input_field_type + except Exception: + return None def is_union_subtype(t1, t2):