mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-01 03:01:13 -04:00
feat(nodes): validate default values for all fields
This prevents issues where the node is defined with an invalid default value, which would guarantee an error during a ser/de roundtrip. - Upstream issue requesting this functionality be built-in to pydantic: https://github.com/pydantic/pydantic/issues/8722 - Upstream PR that implements the functionality: https://github.com/pydantic/pydantic-core/pull/1593
This commit is contained in:
@@ -491,6 +491,31 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
|
||||
return None
|
||||
|
||||
|
||||
class NoDefaultSentinel:
|
||||
pass
|
||||
|
||||
|
||||
def validate_field_default(field_name: str, invocation_type: str, annotation: Any, field_info: FieldInfo) -> None:
|
||||
"""Validates the default value of a field against its pydantic field definition."""
|
||||
|
||||
assert isinstance(field_info.json_schema_extra, dict), "json_schema_extra is not a dict"
|
||||
|
||||
# By the time we are doing this, we've already done some pydantic magic by overriding the original default value.
|
||||
# We store the original default value in the json_schema_extra dict, so we can validate it here.
|
||||
orig_default = field_info.json_schema_extra.get("orig_default", NoDefaultSentinel)
|
||||
|
||||
if orig_default is NoDefaultSentinel:
|
||||
return
|
||||
|
||||
TempDefaultValidator = create_model("TempDefaultValidator", field_to_validate=(annotation, field_info))
|
||||
|
||||
# Validate the default value against the annotation
|
||||
try:
|
||||
TempDefaultValidator(field_to_validate=orig_default)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Default value for field {field_name} on invocation {invocation_type} is invalid, {e}") from e
|
||||
|
||||
|
||||
def is_optional(annotation: Any) -> bool:
|
||||
"""
|
||||
Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None).
|
||||
@@ -545,8 +570,12 @@ def invocation(
|
||||
assert isinstance(field_info.json_schema_extra, dict), (
|
||||
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||
)
|
||||
|
||||
validate_field_default(field_name, invocation_type, annotation, field_info)
|
||||
|
||||
if field_info.default is None and not is_optional(annotation):
|
||||
annotation = annotation | None
|
||||
|
||||
fields[field_name] = (annotation, field_info)
|
||||
|
||||
# Add OpenAPI schema extras
|
||||
|
||||
Reference in New Issue
Block a user