From df81f3274a191de8034c09ca6696e0e6063ed522 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 2 May 2025 07:43:48 +1000 Subject: [PATCH] feat(nodes): improved pydantic type annotation massaging When we do our field type overrides to allow invocations to be instantiated without all required fields, we were not modifying the annotation of the field but did set the default value of the field to `None`. This results in an error when doing a ser/de round trip. Here's what we end up doing: ```py from pydantic import BaseModel, Field class MyModel(BaseModel): foo: str = Field(default=None) ``` And here is a simple round-trip, which should not error but which does: ```py MyModel(**MyModel().model_dump()) # ValidationError: 1 validation error for MyModel # foo # Input should be a valid string [type=string_type, input_value=None, input_type=NoneType] # For further information visit https://errors.pydantic.dev/2.11/v/string_type ``` To fix this, we now check every incoming field and update its annotation to match its default value. In other words, when we override the default field value to `None`, we make its type annotation ` | None`. This prevents the error during deserialization. This slightly alters the schema for all invocations and outputs - the values of all fields without default values are now typed as ` | None`, reflecting the overrides. This means the autogenerated types for fields have also changed for fields without defaults: ```ts // Old image?: components["schemas"]["ImageField"]; // New image?: components["schemas"]["ImageField"] | null; ``` This does not break anything on the frontend. --- invokeai/app/invocations/baseinvocation.py | 76 ++++++++++++++++------ tests/app/invocations/test_is_optional.py | 46 +++++++++++++ 2 files changed, 102 insertions(+), 20 deletions(-) create mode 100644 tests/app/invocations/test_is_optional.py diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 61e76a288e..02d4537c89 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -5,6 +5,8 @@ from __future__ import annotations import inspect import re import sys +import types +import typing import warnings from abc import ABC, abstractmethod from enum import Enum @@ -489,6 +491,18 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None return None +def is_optional(annotation: Any) -> bool: + """ + Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None). + """ + origin = typing.get_origin(annotation) + # PEP 604 unions (int|None) have origin types.UnionType + is_union = origin is typing.Union or origin is types.UnionType + if not is_union: + return False + return any(arg is type(None) for arg in typing.get_args(annotation)) + + def invocation( invocation_type: str, title: Optional[str] = None, @@ -523,6 +537,18 @@ def invocation( validate_fields(cls.model_fields, invocation_type) + fields: dict[str, tuple[Any, FieldInfo]] = {} + + for field_name, field_info in cls.model_fields.items(): + annotation = field_info.annotation + assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation." + assert isinstance(field_info.json_schema_extra, dict), ( + f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?" + ) + if field_info.default is None and not is_optional(annotation): + annotation = annotation | None + fields[field_name] = (annotation, field_info) + # Add OpenAPI schema extras uiconfig: dict[str, Any] = {} uiconfig["title"] = title @@ -557,11 +583,17 @@ def invocation( # Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does # not work. Instead, we have to create a new class with the type field and patch the original class with it. - invocation_type_annotation = Literal[invocation_type] # type: ignore + invocation_type_annotation = Literal[invocation_type] invocation_type_field = Field( title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute} ) + # pydantic's Field function returns a FieldInfo, but they annotate it as returning a type so that type-checkers + # don't get confused by something like this: + # foo: str = Field() <-- this is a FieldInfo, not a str + # Unfortunately this means we need to use type: ignore here to avoid type-checker errors + fields["type"] = (invocation_type_annotation, invocation_type_field) # type: ignore + # Validate the `invoke()` method is implemented if "invoke" in cls.__abstractmethods__: raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method') @@ -583,17 +615,12 @@ def invocation( ) docstring = cls.__doc__ - cls = create_model( - cls.__qualname__, - __base__=cls, - __module__=cls.__module__, - type=(invocation_type_annotation, invocation_type_field), - ) - cls.__doc__ = docstring + new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields) + new_class.__doc__ = docstring - InvocationRegistry.register_invocation(cls) + InvocationRegistry.register_invocation(new_class) - return cls + return new_class return wrapper @@ -618,23 +645,32 @@ def invocation_output( validate_fields(cls.model_fields, output_type) + fields: dict[str, tuple[Any, FieldInfo]] = {} + + for field_name, field_info in cls.model_fields.items(): + annotation = field_info.annotation + assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation." + assert isinstance(field_info.json_schema_extra, dict), ( + f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?" + ) + if field_info.default is not PydanticUndefined and is_optional(annotation): + annotation = annotation | None + fields[field_name] = (annotation, field_info) + # Add the output type to the model. - output_type_annotation = Literal[output_type] # type: ignore + output_type_annotation = Literal[output_type] output_type_field = Field( title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute} ) + fields["type"] = (output_type_annotation, output_type_field) # type: ignore + docstring = cls.__doc__ - cls = create_model( - cls.__qualname__, - __base__=cls, - __module__=cls.__module__, - type=(output_type_annotation, output_type_field), - ) - cls.__doc__ = docstring + new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields) + new_class.__doc__ = docstring - InvocationRegistry.register_output(cls) + InvocationRegistry.register_output(new_class) - return cls + return new_class return wrapper diff --git a/tests/app/invocations/test_is_optional.py b/tests/app/invocations/test_is_optional.py new file mode 100644 index 0000000000..58e2723ba9 --- /dev/null +++ b/tests/app/invocations/test_is_optional.py @@ -0,0 +1,46 @@ +from typing import Any, Literal, Optional, Union + +import pytest +from pydantic import BaseModel + + +class TestModel(BaseModel): + foo: Literal["bar"] = "bar" + + +@pytest.mark.parametrize( + "input_type, expected", + [ + (str, False), + (list[str], False), + (list[dict[str, Any]], False), + (list[None], False), + (list[dict[str, None]], False), + (Any, False), + (True, False), + (False, False), + (Union[str, False], False), + (Union[str, True], False), + (None, False), + (str | None, True), + (Union[str, None], True), + (Optional[str], True), + (str | int | None, True), + (None | str | int, True), + (Union[None, str], True), + (Optional[str], True), + (Optional[int], True), + (Optional[str], True), + (TestModel | None, True), + (Union[TestModel, None], True), + (Optional[TestModel], True), + ], +) +def test_is_optional(input_type: Any, expected: bool) -> None: + """ + Test the is_optional function. + """ + from invokeai.app.invocations.baseinvocation import is_optional + + result = is_optional(input_type) + assert result == expected, f"Expected {expected} but got {result} for input type {input_type}"