From 90232806d91f426bb4096f33d8474bdf75a19fc0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 Apr 2024 17:30:57 +1000 Subject: [PATCH] feat(nodes): add validation for invoke method return types --- invokeai/app/invocations/baseinvocation.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 78d85b4d33..0546479774 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -486,6 +486,26 @@ def invocation( title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute} ) + # Validate the `invoke()` method is implemented + if "invoke" in cls.__abstractmethods__: + raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method') + + # And validate that `invoke()` returns a subclass of `BaseInvocationOutput + invoke_return_annotation = signature(cls.invoke).return_annotation + try: + assert invoke_return_annotation is not BaseInvocationOutput + # TODO(psyche): If `invoke()` is not defined, `return_annotation` ends up as the string + # "BaseInvocationOutput". This may be a pydantic bug: https://github.com/pydantic/pydantic/issues/7978 + # I cannot reproduce this in a simple test case, so I'm not sure how to fix it. + # + # This check should be in a try block, not a conditional, because `issubclass` errors if the first arg is + # not a class (e.g. the string "BaseInvocationOutput"). + assert issubclass(invoke_return_annotation, BaseInvocationOutput) + except Exception: + raise ValueError( + f'Invocation "{invocation_type}" must have a return annotation of a subclass of BaseInvocationOutput (got "{invoke_return_annotation}")' + ) + docstring = cls.__doc__ cls = create_model( cls.__qualname__,