diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index f86dc2d9a1..78d85b4d33 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -192,12 +192,19 @@ class BaseInvocation(ABC, BaseModel): """Gets a pydantc TypeAdapter for the union of all invocation types.""" if not cls._typeadapter or cls._typeadapter_needs_update: AnyInvocation = TypeAliasType( - "AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] + "AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")] ) cls._typeadapter = TypeAdapter(AnyInvocation) cls._typeadapter_needs_update = False return cls._typeadapter + @classmethod + def invalidate_typeadapter(cls) -> None: + """Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or + denylist is changed, this should be called to ensure the typeadapter is updated and validation respects + the updated allowlist and denylist.""" + cls._typeadapter_needs_update = True + @classmethod def get_invocations(cls) -> Iterable[BaseInvocation]: """Gets all invocations, respecting the allowlist and denylist.""" diff --git a/tests/test_config.py b/tests/test_config.py index a6ea2a3480..c7149edc73 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,15 +3,16 @@ from tempfile import TemporaryDirectory from typing import Any import pytest -from omegaconf import OmegaConf from pydantic import ValidationError +from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.config.config_default import ( DefaultInvokeAIAppConfig, InvokeAIAppConfig, get_config, load_and_migrate_config, ) +from invokeai.app.services.shared.graph import Graph from invokeai.frontend.cli.arg_parser import InvokeAIArgs v4_config = """ @@ -265,57 +266,29 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch InvokeAIArgs.did_parse = False -@pytest.mark.xfail( - reason=""" - This test fails when run as part of the full test suite. - - This test needs to deny nodes from being included in the InvocationsUnion by providing - an app configuration as a test fixture. Pytest executes all test files before running - tests, so the app configuration is already initialized by the time this test runs, and - the InvocationUnion is already created and the denied nodes are not omitted from it. - - This test passes when `test_config.py` is tested in isolation. - - Perhaps a solution would be to call `get_app_config().parse_args()` in - other test files? - """ -) def test_deny_nodes(patch_rootdir): # Allow integer, string and float, but explicitly deny float - allow_deny_nodes_conf = OmegaConf.create( - """ - InvokeAI: - Nodes: - allow_nodes: - - integer - - string - - float - deny_nodes: - - float - """ - ) - # must parse config before importing Graph, so its nodes union uses the config - get_config.cache_clear() conf = get_config() - get_config.cache_clear() - conf.merge_from_file(conf=allow_deny_nodes_conf, argv=[]) - from invokeai.app.services.shared.graph import Graph + conf.allow_nodes = ["integer", "string", "float"] + conf.deny_nodes = ["float"] + + # We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for + # subsequent graph validations + BaseInvocation.invalidate_typeadapter() # confirm graph validation fails when using denied node - Graph(nodes={"1": {"id": "1", "type": "integer"}}) - Graph(nodes={"1": {"id": "1", "type": "string"}}) + Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}}) + Graph.model_validate({"nodes": {"1": {"id": "1", "type": "string"}}}) with pytest.raises(ValidationError): - Graph(nodes={"1": {"id": "1", "type": "float"}}) - - from invokeai.app.invocations.baseinvocation import BaseInvocation + Graph.model_validate({"nodes": {"1": {"id": "1", "type": "float"}}}) # confirm invocations union will not have denied nodes all_invocations = BaseInvocation.get_invocations() - has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1 - has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1 - has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1 + has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1 + has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1 + has_float = len([i for i in all_invocations if i.get_type() == "float"]) == 1 assert has_integer assert has_string