mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-19 11:10:05 -05:00
Compare commits
20 Commits
main
...
psyche/fix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b07d7846da | ||
|
|
40f35b3088 | ||
|
|
59deef97c5 | ||
|
|
d852ca7a8d | ||
|
|
d24877561d | ||
|
|
8144a263de | ||
|
|
ab086a7069 | ||
|
|
048306b417 | ||
|
|
6eaed9a9cb | ||
|
|
ab9ebef345 | ||
|
|
984dd93798 | ||
|
|
d12fb7db68 | ||
|
|
5d411e446a | ||
|
|
6f128c86b4 | ||
|
|
aca9e44a3a | ||
|
|
e39f035264 | ||
|
|
b612c73954 | ||
|
|
36495b730d | ||
|
|
6ad1948a44 | ||
|
|
f07a46f195 |
@@ -183,7 +183,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationsUnion = TypeAliasType(
|
||||
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
"InvocationsUnion", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
||||
return cls._typeadapter
|
||||
|
||||
@@ -20,6 +20,8 @@ import invokeai.configs as model_configs
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
from .config_migrate import ConfigMigrator
|
||||
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
@@ -348,75 +350,6 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
|
||||
return (init_settings,)
|
||||
|
||||
|
||||
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
"""Migrate a v3 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v3 config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the migrated settings.
|
||||
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
||||
for k, v in category_dict.items():
|
||||
# `outdir` was renamed to `outputs_dir` in v4
|
||||
if k == "outdir":
|
||||
parsed_config_dict["outputs_dir"] = v
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
parsed_config_dict["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
parsed_config_dict["vram"] = v
|
||||
# autocast was removed in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
if k == "conf_path":
|
||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||
# If if the incoming config has the default value, skip
|
||||
continue
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||
else:
|
||||
# Else we do not attempt to migrate this setting
|
||||
parsed_config_dict["legacy_conf_dir"] = v
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
parsed_config_dict[k] = v
|
||||
# When migrating the config file, we should not include currently-set environment variables.
|
||||
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
"""Migrate v4.0.0 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v4.0.0 config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the migrated settings.
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for k, v in config_dict.items():
|
||||
# autocast was removed from precision in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
else:
|
||||
parsed_config_dict[k] = v
|
||||
if k == "schema_version":
|
||||
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
|
||||
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
|
||||
return config
|
||||
|
||||
|
||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
"""Load and migrate a config file to the latest version.
|
||||
|
||||
@@ -432,29 +365,20 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
|
||||
assert isinstance(loaded_config_dict, dict)
|
||||
|
||||
if "InvokeAI" in loaded_config_dict:
|
||||
# This is a v3 config file, attempt to migrate it
|
||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||
try:
|
||||
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
|
||||
migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
||||
except Exception as e:
|
||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
|
||||
migrated_config.write_file(config_path)
|
||||
return migrated_config
|
||||
|
||||
if loaded_config_dict["schema_version"] == "4.0.0":
|
||||
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
|
||||
loaded_config_dict.write_file(config_path)
|
||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||
try:
|
||||
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
|
||||
migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
||||
except Exception as e:
|
||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
|
||||
|
||||
# Attempt to load as a v4 config file
|
||||
try:
|
||||
# Meta is not included in the model fields, so we need to validate it separately
|
||||
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
||||
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
|
||||
assert (
|
||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
|
||||
return config
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
@@ -504,6 +428,7 @@ def get_config() -> InvokeAIAppConfig:
|
||||
|
||||
if config.config_file_path.exists():
|
||||
config_from_file = load_and_migrate_config(config.config_file_path)
|
||||
config_from_file.write_file(config.config_file_path)
|
||||
# Clobbering here will overwrite any settings that were set via environment variables
|
||||
config.update_config(config_from_file, clobber=False)
|
||||
else:
|
||||
@@ -512,3 +437,73 @@ def get_config() -> InvokeAIAppConfig:
|
||||
default_config.write_file(config.config_file_path, as_example=False)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
####################################################
|
||||
# VERSION MIGRATIONS
|
||||
####################################################
|
||||
|
||||
|
||||
@ConfigMigrator.register(from_version="3.0.0", to_version="4.0.0")
|
||||
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Migrate a v3 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v3 config file.
|
||||
|
||||
Returns:
|
||||
A dictionary of settings from a 4.0.0 config file.
|
||||
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
||||
for k, v in category_dict.items():
|
||||
# `outdir` was renamed to `outputs_dir` in v4
|
||||
if k == "outdir":
|
||||
parsed_config_dict["outputs_dir"] = v
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
parsed_config_dict["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
parsed_config_dict["vram"] = v
|
||||
# autocast was removed in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
if k == "conf_path":
|
||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||
# If if the incoming config has the default value, skip
|
||||
continue
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||
else:
|
||||
# Else we do not attempt to migrate this setting
|
||||
parsed_config_dict["legacy_conf_dir"] = v
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
parsed_config_dict[k] = v
|
||||
return parsed_config_dict
|
||||
|
||||
|
||||
@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1")
|
||||
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Migrate v4.0.0 config dictionary to v4.0.1.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v4.0.0 config file.
|
||||
|
||||
Returns:
|
||||
A dictionary of settings from a v4.0.1 config file
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for k, v in config_dict.items():
|
||||
# autocast was removed from precision in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
else:
|
||||
parsed_config_dict[k] = v
|
||||
return parsed_config_dict
|
||||
|
||||
89
invokeai/app/services/config/config_migrate.py
Normal file
89
invokeai/app/services/config/config_migrate.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
Utility class for migrating among versions of the InvokeAI app config schema.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, List, TypeAlias
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
AppConfigDict: TypeAlias = dict[str, Any]
|
||||
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MigrationEntry:
|
||||
"""Defines an individual migration."""
|
||||
|
||||
from_version: Version
|
||||
to_version: Version
|
||||
function: MigrationFunction
|
||||
|
||||
|
||||
class ConfigMigrator:
|
||||
"""This class allows migrators to register their input and output versions."""
|
||||
|
||||
_migrations: List[MigrationEntry] = []
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls,
|
||||
from_version: str,
|
||||
to_version: str,
|
||||
) -> Callable[[MigrationFunction], MigrationFunction]:
|
||||
"""Define a decorator which registers the migration between two versions."""
|
||||
|
||||
def decorator(function: MigrationFunction) -> MigrationFunction:
|
||||
if any(from_version == m.from_version for m in cls._migrations):
|
||||
raise ValueError(
|
||||
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
|
||||
)
|
||||
cls._migrations.append(
|
||||
MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function)
|
||||
)
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None:
|
||||
current_version = Version("3.0.0")
|
||||
for m in migrations:
|
||||
if current_version != m.from_version:
|
||||
raise ValueError(
|
||||
f"Migration functions are not continuous. Expected from_version={current_version} but got from_version={m.from_version}, for migration function {m.function.__name__}"
|
||||
)
|
||||
current_version = m.to_version
|
||||
|
||||
@classmethod
|
||||
def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
|
||||
"""
|
||||
Use the registered migration steps to bring config up to latest version.
|
||||
|
||||
:param config: The original configuration.
|
||||
:return: The new configuration, lifted up to the latest version.
|
||||
|
||||
As a side effect, the new configuration will be written to disk.
|
||||
If an inconsistency in the registered migration steps' `from_version`
|
||||
and `to_version` parameters are identified, this will raise a
|
||||
ValueError exception.
|
||||
"""
|
||||
# Sort migrations by version number and raise a ValueError if
|
||||
# any version range overlaps are detected.
|
||||
sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version)
|
||||
cls._check_for_discontinuities(sorted_migrations)
|
||||
|
||||
if "InvokeAI" in config_dict:
|
||||
version = Version("3.0.0")
|
||||
else:
|
||||
version = Version(config_dict["schema_version"])
|
||||
|
||||
for migration in sorted_migrations:
|
||||
if version == migration.from_version and version < migration.to_version:
|
||||
config_dict = migration.function(config_dict)
|
||||
version = migration.to_version
|
||||
|
||||
config_dict["schema_version"] = str(version)
|
||||
return config_dict
|
||||
@@ -1,22 +1,36 @@
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any
|
||||
from typing import Any, Generator
|
||||
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation, StringInvocation
|
||||
from invokeai.app.services.config.config_default import (
|
||||
CONFIG_SCHEMA_VERSION,
|
||||
DefaultInvokeAIAppConfig,
|
||||
InvokeAIAppConfig,
|
||||
get_config,
|
||||
load_and_migrate_config,
|
||||
)
|
||||
from invokeai.app.services.config.config_migrate import ConfigMigrator
|
||||
from invokeai.app.services.shared.graph import Graph
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
invalid_v4_0_1_config = """
|
||||
schema_version: 4.0.1
|
||||
|
||||
host: "192.168.1.1"
|
||||
port: "ice cream"
|
||||
"""
|
||||
|
||||
v4_config = """
|
||||
schema_version: 4.0.0
|
||||
|
||||
precision: autocast
|
||||
host: "192.168.1.1"
|
||||
port: 8080
|
||||
"""
|
||||
@@ -134,6 +148,16 @@ def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
|
||||
assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config
|
||||
|
||||
|
||||
def test_migrate_v4(tmp_path: Path, patch_rootdir: None):
|
||||
"""Test migration from 4.0.0 to 4.0.1"""
|
||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||
temp_config_file.write_text(v4_config)
|
||||
|
||||
conf = load_and_migrate_config(temp_config_file)
|
||||
assert Version(conf.schema_version) >= Version("4.0.1")
|
||||
assert conf.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration
|
||||
|
||||
|
||||
def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
|
||||
"""Test the failed migration of the config file."""
|
||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||
@@ -156,12 +180,14 @@ def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None):
|
||||
load_and_migrate_config(temp_config_file)
|
||||
|
||||
|
||||
def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None):
|
||||
@pytest.mark.parametrize("config_content", [invalid_v5_config, invalid_v4_0_1_config])
|
||||
def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None, config_content: str):
|
||||
"""Test reading configuration from a file."""
|
||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||
temp_config_file.write_text(invalid_v5_config)
|
||||
temp_config_file.write_text(config_content)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid schema version"):
|
||||
# with pytest.raises(RuntimeError, match="Invalid schema version"):
|
||||
with pytest.raises(RuntimeError):
|
||||
load_and_migrate_config(temp_config_file)
|
||||
|
||||
|
||||
@@ -265,58 +291,71 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
|
||||
InvokeAIArgs.did_parse = False
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="""
|
||||
This test fails when run as part of the full test suite.
|
||||
def test_migration_check() -> None:
|
||||
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"})
|
||||
assert new_config is not None
|
||||
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION
|
||||
|
||||
This test needs to deny nodes from being included in the InvocationsUnion by providing
|
||||
an app configuration as a test fixture. Pytest executes all test files before running
|
||||
tests, so the app configuration is already initialized by the time this test runs, and
|
||||
the InvocationUnion is already created and the denied nodes are not omitted from it.
|
||||
# Does this execute at compile time or run time?
|
||||
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION, to_version=CONFIG_SCHEMA_VERSION + ".1")
|
||||
def ok_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
This test passes when `test_config.py` is tested in isolation.
|
||||
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"})
|
||||
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".1"
|
||||
|
||||
Perhaps a solution would be to call `get_app_config().parse_args()` in
|
||||
other test files?
|
||||
"""
|
||||
)
|
||||
def test_deny_nodes(patch_rootdir):
|
||||
# Allow integer, string and float, but explicitly deny float
|
||||
allow_deny_nodes_conf = OmegaConf.create(
|
||||
"""
|
||||
InvokeAI:
|
||||
Nodes:
|
||||
allow_nodes:
|
||||
- integer
|
||||
- string
|
||||
- float
|
||||
deny_nodes:
|
||||
- float
|
||||
"""
|
||||
)
|
||||
# must parse config before importing Graph, so its nodes union uses the config
|
||||
get_config.cache_clear()
|
||||
conf = get_config()
|
||||
get_config.cache_clear()
|
||||
conf.merge_from_file(conf=allow_deny_nodes_conf, argv=[])
|
||||
from invokeai.app.services.shared.graph import Graph
|
||||
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".2", to_version=CONFIG_SCHEMA_VERSION + ".3")
|
||||
def bad_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
# confirm graph validation fails when using denied node
|
||||
Graph(nodes={"1": {"id": "1", "type": "integer"}})
|
||||
Graph(nodes={"1": {"id": "1", "type": "string"}})
|
||||
# Because there is no version for "*.1" => "*.2", this should fail.
|
||||
with pytest.raises(ValueError):
|
||||
ConfigMigrator.migrate({"schema_version": "4.0.0"})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".1", to_version=CONFIG_SCHEMA_VERSION + ".2")
|
||||
def good_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
# should work now, because there is a continuous path to *.3
|
||||
new_config = ConfigMigrator.migrate(new_config)
|
||||
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".3"
|
||||
|
||||
# confirm invocations union will not have denied nodes
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
|
||||
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
|
||||
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
|
||||
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
|
||||
@contextmanager
|
||||
def clear_config() -> Generator[None, None, None]:
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
# First clear the config cache to avoid interfering with later tests
|
||||
get_config.cache_clear()
|
||||
# Clear the BaseInvocation's cached typeadapter as well, for same reason.
|
||||
BaseInvocation._typeadapter = None # FIXME: Don't use protected members
|
||||
|
||||
assert has_integer
|
||||
assert has_string
|
||||
assert not has_float
|
||||
|
||||
def test_deny_nodes() -> None:
|
||||
with clear_config():
|
||||
config = get_config()
|
||||
config.allow_nodes = ["integer", "string", "float"]
|
||||
config.deny_nodes = ["float"]
|
||||
|
||||
# confirm graph validation fails when using denied node
|
||||
Graph(nodes={"1": IntegerInvocation(value=1)})
|
||||
Graph(nodes={"1": StringInvocation(value="asdf")})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
Graph(nodes={"1": FloatInvocation(value=1.0)})
|
||||
|
||||
# Also test with a dict input
|
||||
with pytest.raises(ValidationError):
|
||||
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||
|
||||
# confirm invocations union will not have denied nodes
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
|
||||
has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1
|
||||
has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1
|
||||
does_not_have_float = len([i for i in all_invocations if i.get_type() == "float"]) == 0
|
||||
|
||||
assert has_integer
|
||||
assert has_string
|
||||
assert does_not_have_float
|
||||
|
||||
Reference in New Issue
Block a user