mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-09 20:05:04 -05:00
feat(nodes): improved pydantic type annotation massaging
When we do our field type overrides to allow invocations to be instantiated without all required fields, we were not modifying the annotation of the field but did set the default value of the field to `None`.
This results in an error when doing a ser/de round trip. Here's what we end up doing:
```py
from pydantic import BaseModel, Field
class MyModel(BaseModel):
foo: str = Field(default=None)
```
And here is a simple round-trip, which should not error but which does:
```py
MyModel(**MyModel().model_dump())
# ValidationError: 1 validation error for MyModel
# foo
# Input should be a valid string [type=string_type, input_value=None, input_type=NoneType]
# For further information visit https://errors.pydantic.dev/2.11/v/string_type
```
To fix this, we now check every incoming field and update its annotation to match its default value. In other words, when we override the default field value to `None`, we make its type annotation `<original type> | None`.
This prevents the error during deserialization.
This slightly alters the schema for all invocations and outputs - the values of all fields without default values are now typed as `<original type> | None`, reflecting the overrides.
This means the autogenerated types for fields have also changed for fields without defaults:
```ts
// Old
image?: components["schemas"]["ImageField"];
// New
image?: components["schemas"]["ImageField"] | null;
```
This does not break anything on the frontend.
This commit is contained in:
@@ -5,6 +5,8 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@@ -489,6 +491,18 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
|
||||
return None
|
||||
|
||||
|
||||
def is_optional(annotation: Any) -> bool:
|
||||
"""
|
||||
Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None).
|
||||
"""
|
||||
origin = typing.get_origin(annotation)
|
||||
# PEP 604 unions (int|None) have origin types.UnionType
|
||||
is_union = origin is typing.Union or origin is types.UnionType
|
||||
if not is_union:
|
||||
return False
|
||||
return any(arg is type(None) for arg in typing.get_args(annotation))
|
||||
|
||||
|
||||
def invocation(
|
||||
invocation_type: str,
|
||||
title: Optional[str] = None,
|
||||
@@ -523,6 +537,18 @@ def invocation(
|
||||
|
||||
validate_fields(cls.model_fields, invocation_type)
|
||||
|
||||
fields: dict[str, tuple[Any, FieldInfo]] = {}
|
||||
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
annotation = field_info.annotation
|
||||
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
|
||||
assert isinstance(field_info.json_schema_extra, dict), (
|
||||
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||
)
|
||||
if field_info.default is None and not is_optional(annotation):
|
||||
annotation = annotation | None
|
||||
fields[field_name] = (annotation, field_info)
|
||||
|
||||
# Add OpenAPI schema extras
|
||||
uiconfig: dict[str, Any] = {}
|
||||
uiconfig["title"] = title
|
||||
@@ -557,11 +583,17 @@ def invocation(
|
||||
# Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
|
||||
# not work. Instead, we have to create a new class with the type field and patch the original class with it.
|
||||
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
invocation_type_annotation = Literal[invocation_type]
|
||||
invocation_type_field = Field(
|
||||
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
|
||||
)
|
||||
|
||||
# pydantic's Field function returns a FieldInfo, but they annotate it as returning a type so that type-checkers
|
||||
# don't get confused by something like this:
|
||||
# foo: str = Field() <-- this is a FieldInfo, not a str
|
||||
# Unfortunately this means we need to use type: ignore here to avoid type-checker errors
|
||||
fields["type"] = (invocation_type_annotation, invocation_type_field) # type: ignore
|
||||
|
||||
# Validate the `invoke()` method is implemented
|
||||
if "invoke" in cls.__abstractmethods__:
|
||||
raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method')
|
||||
@@ -583,17 +615,12 @@ def invocation(
|
||||
)
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
cls.__qualname__,
|
||||
__base__=cls,
|
||||
__module__=cls.__module__,
|
||||
type=(invocation_type_annotation, invocation_type_field),
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
|
||||
new_class.__doc__ = docstring
|
||||
|
||||
InvocationRegistry.register_invocation(cls)
|
||||
InvocationRegistry.register_invocation(new_class)
|
||||
|
||||
return cls
|
||||
return new_class
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -618,23 +645,32 @@ def invocation_output(
|
||||
|
||||
validate_fields(cls.model_fields, output_type)
|
||||
|
||||
fields: dict[str, tuple[Any, FieldInfo]] = {}
|
||||
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
annotation = field_info.annotation
|
||||
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
|
||||
assert isinstance(field_info.json_schema_extra, dict), (
|
||||
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||
)
|
||||
if field_info.default is not PydanticUndefined and is_optional(annotation):
|
||||
annotation = annotation | None
|
||||
fields[field_name] = (annotation, field_info)
|
||||
|
||||
# Add the output type to the model.
|
||||
output_type_annotation = Literal[output_type] # type: ignore
|
||||
output_type_annotation = Literal[output_type]
|
||||
output_type_field = Field(
|
||||
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
|
||||
)
|
||||
|
||||
fields["type"] = (output_type_annotation, output_type_field) # type: ignore
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
cls.__qualname__,
|
||||
__base__=cls,
|
||||
__module__=cls.__module__,
|
||||
type=(output_type_annotation, output_type_field),
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
|
||||
new_class.__doc__ = docstring
|
||||
|
||||
InvocationRegistry.register_output(cls)
|
||||
InvocationRegistry.register_output(new_class)
|
||||
|
||||
return cls
|
||||
return new_class
|
||||
|
||||
return wrapper
|
||||
|
||||
Reference in New Issue
Block a user