feat(nodes): move invocation/output registration to separate class

This commit is contained in:
psychedelicious
2025-03-11 10:02:22 +10:00
parent 7be87c8048
commit 6155f9ff9e
4 changed files with 101 additions and 97 deletions

View File

@@ -100,35 +100,6 @@ class BaseInvocationOutput(BaseModel):
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
"""
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls.get_typeadapter.cache_clear()
@classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
"""Gets all invocation outputs."""
return cls._output_classes
@classmethod
@lru_cache(maxsize=1)
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation output types.
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.
"""
return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")])
@classmethod
def get_output_types(cls) -> Iterable[str]:
"""Gets all invocation output types."""
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
@@ -171,67 +142,16 @@ class BaseInvocation(ABC, BaseModel):
All invocations must use the `@invocation` decorator to provide their unique type.
"""
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
@classmethod
def get_type(cls) -> str:
"""Gets the invocation's type, as provided by the `@invocation` decorator."""
return cls.model_fields["type"].default
@classmethod
def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls.get_typeadapter.cache_clear()
@classmethod
@lru_cache(maxsize=1)
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation types.
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.
"""
return TypeAdapter(Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")])
@classmethod
def get_invocations(cls) -> Iterable[BaseInvocation]:
"""Gets all invocations, respecting the allowlist and denylist."""
app_config = get_config()
allowed_invocations: set[BaseInvocation] = set()
for sc in cls._invocation_classes:
invocation_type = sc.get_type()
is_in_allowlist = (
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
)
is_in_denylist = (
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
)
if is_in_allowlist and not is_in_denylist:
allowed_invocations.add(sc)
return allowed_invocations
@classmethod
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
"""Gets a map of all invocation types to their invocation classes."""
return {i.get_type(): i for i in BaseInvocation.get_invocations()}
@classmethod
def get_invocation_types(cls) -> Iterable[str]:
"""Gets all invocation types."""
return (i.get_type() for i in BaseInvocation.get_invocations())
@classmethod
def get_output_annotation(cls) -> BaseInvocationOutput:
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
return signature(cls.invoke).return_annotation
@classmethod
def get_invocation_for_type(cls, invocation_type: str) -> BaseInvocation | None:
"""Gets the invocation class for a given invocation type."""
return cls.get_invocations_map().get(invocation_type)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
@@ -329,6 +249,87 @@ class BaseInvocation(ABC, BaseModel):
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
class InvocationRegistry:
_invocation_classes: ClassVar[set[type[BaseInvocation]]] = set()
_output_classes: ClassVar[set[type[BaseInvocationOutput]]] = set()
@classmethod
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls.get_invocation_typeadapter.cache_clear()
@classmethod
@lru_cache(maxsize=1)
def get_invocation_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation types.
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.
"""
return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")])
@classmethod
def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]:
"""Gets all invocations, respecting the allowlist and denylist."""
app_config = get_config()
allowed_invocations: set[type[BaseInvocation]] = set()
for sc in cls._invocation_classes:
invocation_type = sc.get_type()
is_in_allowlist = (
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
)
is_in_denylist = (
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
)
if is_in_allowlist and not is_in_denylist:
allowed_invocations.add(sc)
return allowed_invocations
@classmethod
def get_invocations_map(cls) -> dict[str, type[BaseInvocation]]:
"""Gets a map of all invocation types to their invocation classes."""
return {i.get_type(): i for i in cls.get_invocation_classes()}
@classmethod
def get_invocation_types(cls) -> Iterable[str]:
"""Gets all invocation types."""
return (i.get_type() for i in cls.get_invocation_classes())
@classmethod
def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] | None:
"""Gets the invocation class for a given invocation type."""
return cls.get_invocations_map().get(invocation_type)
@classmethod
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls.get_output_typeadapter.cache_clear()
@classmethod
def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
"""Gets all invocation outputs."""
return cls._output_classes
@classmethod
@lru_cache(maxsize=1)
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation output types.
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.
"""
return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")])
@classmethod
def get_output_types(cls) -> Iterable[str]:
"""Gets all invocation output types."""
return (i.get_type() for i in cls.get_output_classes())
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"id",
"is_intermediate",
@@ -442,8 +443,8 @@ def invocation(
node_pack = cls.__module__.split(".")[0]
# Handle the case where an existing node is being clobbered by the one we are registering
if invocation_type in BaseInvocation.get_invocation_types():
clobbered_invocation = BaseInvocation.get_invocation_for_type(invocation_type)
if invocation_type in InvocationRegistry.get_invocation_types():
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
# This should always be true - we just checked if the invocation type was in the set
assert clobbered_invocation is not None
@@ -528,8 +529,7 @@ def invocation(
)
cls.__doc__ = docstring
# TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
BaseInvocation.register_invocation(cls) # type: ignore
InvocationRegistry.register_invocation(cls)
return cls
@@ -554,7 +554,7 @@ def invocation_output(
if re.compile(r"^\S+$").match(output_type) is None:
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
if output_type in BaseInvocationOutput.get_output_types():
if output_type in InvocationRegistry.get_output_types():
raise ValueError(f'Invocation type "{output_type}" already exists')
validate_fields(cls.model_fields, output_type)
@@ -575,7 +575,7 @@ def invocation_output(
)
cls.__doc__ = docstring
BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly?
InvocationRegistry.register_output(cls)
return cls

View File

@@ -21,6 +21,7 @@ from invokeai.app.invocations import * # noqa: F401 F403
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationRegistry,
invocation,
invocation_output,
)
@@ -283,7 +284,7 @@ class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return InvocationRegistry.get_invocation_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation)
@@ -294,7 +295,7 @@ class AnyInvocation(BaseInvocation):
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocation.get_invocations()]
names = [i.__name__ for i in InvocationRegistry.get_invocation_classes()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
@@ -304,7 +305,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return InvocationRegistry.get_output_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation_output)
@@ -316,7 +317,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
names = [i.__name__ for i in InvocationRegistry.get_output_classes()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}

View File

@@ -4,7 +4,10 @@ from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.baseinvocation import (
InvocationRegistry,
UIConfigBase,
)
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
@@ -56,14 +59,14 @@ def get_openapi_func(
invocation_output_map_required: list[str] = []
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in BaseInvocationOutput.get_outputs():
for output in InvocationRegistry.get_output_classes():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
# property, so we'll just do it all manually.
for invocation in BaseInvocation.get_invocations():
for invocation in InvocationRegistry.get_invocation_classes():
json_schema = invocation.model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)

View File

@@ -5,7 +5,7 @@ from typing import Any
import pytest
from pydantic import ValidationError
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.baseinvocation import InvocationRegistry
from invokeai.app.services.config.config_default import (
DefaultInvokeAIAppConfig,
InvokeAIAppConfig,
@@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir):
# We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for
# subsequent graph validations
BaseInvocation.get_typeadapter.cache_clear()
InvocationRegistry.get_invocation_typeadapter.cache_clear()
# confirm graph validation fails when using denied node
Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}})
@@ -284,7 +284,7 @@ def test_deny_nodes(patch_rootdir):
Graph.model_validate({"nodes": {"1": {"id": "1", "type": "float"}}})
# confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations()
all_invocations = InvocationRegistry.get_invocation_classes()
has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1
has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1
@@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir):
# Reset the config so that it doesn't affect other tests
get_config.cache_clear()
BaseInvocation.get_typeadapter.cache_clear()
InvocationRegistry.get_invocation_typeadapter.cache_clear()