diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 61e76a288e..02d4537c89 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -5,6 +5,8 @@ from __future__ import annotations import inspect import re import sys +import types +import typing import warnings from abc import ABC, abstractmethod from enum import Enum @@ -489,6 +491,18 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None return None +def is_optional(annotation: Any) -> bool: + """ + Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None). + """ + origin = typing.get_origin(annotation) + # PEP 604 unions (int|None) have origin types.UnionType + is_union = origin is typing.Union or origin is types.UnionType + if not is_union: + return False + return any(arg is type(None) for arg in typing.get_args(annotation)) + + def invocation( invocation_type: str, title: Optional[str] = None, @@ -523,6 +537,18 @@ def invocation( validate_fields(cls.model_fields, invocation_type) + fields: dict[str, tuple[Any, FieldInfo]] = {} + + for field_name, field_info in cls.model_fields.items(): + annotation = field_info.annotation + assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation." + assert isinstance(field_info.json_schema_extra, dict), ( + f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?" + ) + if field_info.default is None and not is_optional(annotation): + annotation = annotation | None + fields[field_name] = (annotation, field_info) + # Add OpenAPI schema extras uiconfig: dict[str, Any] = {} uiconfig["title"] = title @@ -557,11 +583,17 @@ def invocation( # Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does # not work. Instead, we have to create a new class with the type field and patch the original class with it. - invocation_type_annotation = Literal[invocation_type] # type: ignore + invocation_type_annotation = Literal[invocation_type] invocation_type_field = Field( title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute} ) + # pydantic's Field function returns a FieldInfo, but they annotate it as returning a type so that type-checkers + # don't get confused by something like this: + # foo: str = Field() <-- this is a FieldInfo, not a str + # Unfortunately this means we need to use type: ignore here to avoid type-checker errors + fields["type"] = (invocation_type_annotation, invocation_type_field) # type: ignore + # Validate the `invoke()` method is implemented if "invoke" in cls.__abstractmethods__: raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method') @@ -583,17 +615,12 @@ def invocation( ) docstring = cls.__doc__ - cls = create_model( - cls.__qualname__, - __base__=cls, - __module__=cls.__module__, - type=(invocation_type_annotation, invocation_type_field), - ) - cls.__doc__ = docstring + new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields) + new_class.__doc__ = docstring - InvocationRegistry.register_invocation(cls) + InvocationRegistry.register_invocation(new_class) - return cls + return new_class return wrapper @@ -618,23 +645,32 @@ def invocation_output( validate_fields(cls.model_fields, output_type) + fields: dict[str, tuple[Any, FieldInfo]] = {} + + for field_name, field_info in cls.model_fields.items(): + annotation = field_info.annotation + assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation." + assert isinstance(field_info.json_schema_extra, dict), ( + f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?" + ) + if field_info.default is not PydanticUndefined and is_optional(annotation): + annotation = annotation | None + fields[field_name] = (annotation, field_info) + # Add the output type to the model. - output_type_annotation = Literal[output_type] # type: ignore + output_type_annotation = Literal[output_type] output_type_field = Field( title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute} ) + fields["type"] = (output_type_annotation, output_type_field) # type: ignore + docstring = cls.__doc__ - cls = create_model( - cls.__qualname__, - __base__=cls, - __module__=cls.__module__, - type=(output_type_annotation, output_type_field), - ) - cls.__doc__ = docstring + new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields) + new_class.__doc__ = docstring - InvocationRegistry.register_output(cls) + InvocationRegistry.register_output(new_class) - return cls + return new_class return wrapper diff --git a/tests/app/invocations/test_is_optional.py b/tests/app/invocations/test_is_optional.py new file mode 100644 index 0000000000..58e2723ba9 --- /dev/null +++ b/tests/app/invocations/test_is_optional.py @@ -0,0 +1,46 @@ +from typing import Any, Literal, Optional, Union + +import pytest +from pydantic import BaseModel + + +class TestModel(BaseModel): + foo: Literal["bar"] = "bar" + + +@pytest.mark.parametrize( + "input_type, expected", + [ + (str, False), + (list[str], False), + (list[dict[str, Any]], False), + (list[None], False), + (list[dict[str, None]], False), + (Any, False), + (True, False), + (False, False), + (Union[str, False], False), + (Union[str, True], False), + (None, False), + (str | None, True), + (Union[str, None], True), + (Optional[str], True), + (str | int | None, True), + (None | str | int, True), + (Union[None, str], True), + (Optional[str], True), + (Optional[int], True), + (Optional[str], True), + (TestModel | None, True), + (Union[TestModel, None], True), + (Optional[TestModel], True), + ], +) +def test_is_optional(input_type: Any, expected: bool) -> None: + """ + Test the is_optional function. + """ + from invokeai.app.invocations.baseinvocation import is_optional + + result = is_optional(input_type) + assert result == expected, f"Expected {expected} but got {result} for input type {input_type}"