Compare commits

...

20 Commits

Author SHA1 Message Date
Lincoln Stein
b07d7846da Merge branch 'main' into psyche/fix/nodes-denylist 2024-04-28 15:10:54 -04:00
Lincoln Stein
40f35b3088 fix test_deny_nodes() test to not interfere with later tests 2024-04-28 15:10:43 -04:00
Lincoln Stein
59deef97c5 Merge branch 'main' into lstein/feat/config-migration 2024-04-28 14:31:43 -04:00
Lincoln Stein
d852ca7a8d added test for non-contiguous migration routines 2024-04-28 14:31:38 -04:00
Lincoln Stein
d24877561d reinstated failing deny_nodes validation test for Graph 2024-04-25 00:22:09 -04:00
Lincoln Stein
8144a263de updated and reinstated the test_deny_nodes() unit test 2024-04-24 22:14:48 -04:00
Lincoln Stein
ab086a7069 Merge branch 'main' into lstein/feat/config-migration 2024-04-24 21:37:46 -04:00
Lincoln Stein
048306b417 Merge branch 'main' into lstein/feat/config-migration 2024-04-24 21:37:12 -04:00
Lincoln Stein
6eaed9a9cb check for strictly contiguous from_version->to_version ranges 2024-04-24 21:36:28 -04:00
psychedelicious
ab9ebef345 tests(config): fix typo 2024-04-23 17:52:51 +10:00
psychedelicious
984dd93798 tests(config): add failing test case to for config migrator 2024-04-23 17:50:31 +10:00
psychedelicious
d12fb7db68 fix(config): fix duplicate migration logic
This was checking a `Version` object against a `MigrationEntry`, but what we want is to check the version object against `MigrationEntry.from_version`
2024-04-23 17:25:53 +10:00
psychedelicious
5d411e446a tidy(config): use a type alias for the migration function 2024-04-23 17:21:05 +10:00
psychedelicious
6f128c86b4 tidy(config): use dataclass for MigrationEntry
The only pydantic usage was to convert strings to `Version` objects. The reason to do this conversion was to allow the register decorator to accept strings. MigrationEntry is only created inside this class, so we can just create versions from each migration when instantiating MigrationEntry instead.

Also, pydantic doesn't provide runtime time checking for arbitrary classes like Version, so we don't get any real benefit.
2024-04-23 17:19:54 +10:00
psychedelicious
aca9e44a3a fix(config): use TypeAlias instead of TypeVar
TypeVar is for generics, but the usage here is as an alias
2024-04-23 17:12:19 +10:00
psychedelicious
e39f035264 tidy(config): removed extraneous ABC
We don't need separate implementations for this class, let's not complicate it with an ABC
2024-04-23 17:11:13 +10:00
psychedelicious
b612c73954 tidy(config): remove unused TYPE_CHECKING block 2024-04-23 17:09:50 +10:00
Lincoln Stein
36495b730d use packaging.version rather than version-parse 2024-04-18 23:07:54 -04:00
Lincoln Stein
6ad1948a44 add InvokeAIAppConfig schema migration system 2024-04-18 21:33:54 -04:00
psychedelicious
f07a46f195 fix(nodes): respect nodes denylist
In #5838 graph validation was updated to resolve the issue with the nodes union and import order. That broke the nodes denylist functionality.

However, because the corresponding test was marked as `xfail`, we didn't catch the issue.

- Fix the nodes denylist handling
- Update the tests
2024-03-08 12:52:41 +11:00
4 changed files with 262 additions and 139 deletions

View File

@@ -183,7 +183,7 @@ class BaseInvocation(ABC, BaseModel):
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter:
InvocationsUnion = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
"InvocationsUnion", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationsUnion)
return cls._typeadapter

View File

@@ -20,6 +20,8 @@ import invokeai.configs as model_configs
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
from .config_migrate import ConfigMigrator
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
@@ -348,75 +350,6 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
return (init_settings,)
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
# When migrating the config file, we should not include currently-set environment variables.
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.0 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@@ -432,29 +365,20 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
assert isinstance(loaded_config_dict, dict)
if "InvokeAI" in loaded_config_dict:
# This is a v3 config file, attempt to migrate it
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
@@ -504,6 +428,7 @@ def get_config() -> InvokeAIAppConfig:
if config.config_file_path.exists():
config_from_file = load_and_migrate_config(config.config_file_path)
config_from_file.write_file(config.config_file_path)
# Clobbering here will overwrite any settings that were set via environment variables
config.update_config(config_from_file, clobber=False)
else:
@@ -512,3 +437,73 @@ def get_config() -> InvokeAIAppConfig:
default_config.write_file(config.config_file_path, as_example=False)
return config
####################################################
# VERSION MIGRATIONS
####################################################
@ConfigMigrator.register(from_version="3.0.0", to_version="4.0.0")
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
A dictionary of settings from a 4.0.0 config file.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
return parsed_config_dict
@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1")
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to v4.0.1.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
A dictionary of settings from a v4.0.1 config file
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
return parsed_config_dict

View File

@@ -0,0 +1,89 @@
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Utility class for migrating among versions of the InvokeAI app config schema.
"""
from dataclasses import dataclass
from typing import Any, Callable, List, TypeAlias
from packaging.version import Version
AppConfigDict: TypeAlias = dict[str, Any]
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
@dataclass
class MigrationEntry:
"""Defines an individual migration."""
from_version: Version
to_version: Version
function: MigrationFunction
class ConfigMigrator:
"""This class allows migrators to register their input and output versions."""
_migrations: List[MigrationEntry] = []
@classmethod
def register(
cls,
from_version: str,
to_version: str,
) -> Callable[[MigrationFunction], MigrationFunction]:
"""Define a decorator which registers the migration between two versions."""
def decorator(function: MigrationFunction) -> MigrationFunction:
if any(from_version == m.from_version for m in cls._migrations):
raise ValueError(
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
)
cls._migrations.append(
MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function)
)
return function
return decorator
@staticmethod
def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None:
current_version = Version("3.0.0")
for m in migrations:
if current_version != m.from_version:
raise ValueError(
f"Migration functions are not continuous. Expected from_version={current_version} but got from_version={m.from_version}, for migration function {m.function.__name__}"
)
current_version = m.to_version
@classmethod
def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
"""
Use the registered migration steps to bring config up to latest version.
:param config: The original configuration.
:return: The new configuration, lifted up to the latest version.
As a side effect, the new configuration will be written to disk.
If an inconsistency in the registered migration steps' `from_version`
and `to_version` parameters are identified, this will raise a
ValueError exception.
"""
# Sort migrations by version number and raise a ValueError if
# any version range overlaps are detected.
sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version)
cls._check_for_discontinuities(sorted_migrations)
if "InvokeAI" in config_dict:
version = Version("3.0.0")
else:
version = Version(config_dict["schema_version"])
for migration in sorted_migrations:
if version == migration.from_version and version < migration.to_version:
config_dict = migration.function(config_dict)
version = migration.to_version
config_dict["schema_version"] = str(version)
return config_dict

View File

@@ -1,22 +1,36 @@
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
from typing import Any, Generator
import pytest
from omegaconf import OmegaConf
from packaging.version import Version
from pydantic import ValidationError
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation, StringInvocation
from invokeai.app.services.config.config_default import (
CONFIG_SCHEMA_VERSION,
DefaultInvokeAIAppConfig,
InvokeAIAppConfig,
get_config,
load_and_migrate_config,
)
from invokeai.app.services.config.config_migrate import ConfigMigrator
from invokeai.app.services.shared.graph import Graph
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
invalid_v4_0_1_config = """
schema_version: 4.0.1
host: "192.168.1.1"
port: "ice cream"
"""
v4_config = """
schema_version: 4.0.0
precision: autocast
host: "192.168.1.1"
port: 8080
"""
@@ -134,6 +148,16 @@ def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config
def test_migrate_v4(tmp_path: Path, patch_rootdir: None):
"""Test migration from 4.0.0 to 4.0.1"""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v4_config)
conf = load_and_migrate_config(temp_config_file)
assert Version(conf.schema_version) >= Version("4.0.1")
assert conf.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration
def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
"""Test the failed migration of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
@@ -156,12 +180,14 @@ def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None):
load_and_migrate_config(temp_config_file)
def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None):
@pytest.mark.parametrize("config_content", [invalid_v5_config, invalid_v4_0_1_config])
def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None, config_content: str):
"""Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(invalid_v5_config)
temp_config_file.write_text(config_content)
with pytest.raises(RuntimeError, match="Invalid schema version"):
# with pytest.raises(RuntimeError, match="Invalid schema version"):
with pytest.raises(RuntimeError):
load_and_migrate_config(temp_config_file)
@@ -265,58 +291,71 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
InvokeAIArgs.did_parse = False
@pytest.mark.xfail(
reason="""
This test fails when run as part of the full test suite.
def test_migration_check() -> None:
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"})
assert new_config is not None
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION
This test needs to deny nodes from being included in the InvocationsUnion by providing
an app configuration as a test fixture. Pytest executes all test files before running
tests, so the app configuration is already initialized by the time this test runs, and
the InvocationUnion is already created and the denied nodes are not omitted from it.
# Does this execute at compile time or run time?
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION, to_version=CONFIG_SCHEMA_VERSION + ".1")
def ok_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
This test passes when `test_config.py` is tested in isolation.
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"})
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".1"
Perhaps a solution would be to call `get_app_config().parse_args()` in
other test files?
"""
)
def test_deny_nodes(patch_rootdir):
# Allow integer, string and float, but explicitly deny float
allow_deny_nodes_conf = OmegaConf.create(
"""
InvokeAI:
Nodes:
allow_nodes:
- integer
- string
- float
deny_nodes:
- float
"""
)
# must parse config before importing Graph, so its nodes union uses the config
get_config.cache_clear()
conf = get_config()
get_config.cache_clear()
conf.merge_from_file(conf=allow_deny_nodes_conf, argv=[])
from invokeai.app.services.shared.graph import Graph
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".2", to_version=CONFIG_SCHEMA_VERSION + ".3")
def bad_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
# confirm graph validation fails when using denied node
Graph(nodes={"1": {"id": "1", "type": "integer"}})
Graph(nodes={"1": {"id": "1", "type": "string"}})
# Because there is no version for "*.1" => "*.2", this should fail.
with pytest.raises(ValueError):
ConfigMigrator.migrate({"schema_version": "4.0.0"})
with pytest.raises(ValidationError):
Graph(nodes={"1": {"id": "1", "type": "float"}})
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".1", to_version=CONFIG_SCHEMA_VERSION + ".2")
def good_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
from invokeai.app.invocations.baseinvocation import BaseInvocation
# should work now, because there is a continuous path to *.3
new_config = ConfigMigrator.migrate(new_config)
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".3"
# confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations()
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
@contextmanager
def clear_config() -> Generator[None, None, None]:
try:
yield None
finally:
# First clear the config cache to avoid interfering with later tests
get_config.cache_clear()
# Clear the BaseInvocation's cached typeadapter as well, for same reason.
BaseInvocation._typeadapter = None # FIXME: Don't use protected members
assert has_integer
assert has_string
assert not has_float
def test_deny_nodes() -> None:
with clear_config():
config = get_config()
config.allow_nodes = ["integer", "string", "float"]
config.deny_nodes = ["float"]
# confirm graph validation fails when using denied node
Graph(nodes={"1": IntegerInvocation(value=1)})
Graph(nodes={"1": StringInvocation(value="asdf")})
with pytest.raises(ValidationError):
Graph(nodes={"1": FloatInvocation(value=1.0)})
# Also test with a dict input
with pytest.raises(ValidationError):
Graph(nodes={"1": {"id": "1", "type": "float"}})
# confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations()
has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1
has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1
does_not_have_float = len([i for i in all_invocations if i.get_type() == "float"]) == 0
assert has_integer
assert has_string
assert does_not_have_float