mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 07:28:06 -05:00
Compare commits
33 Commits
v6.5.0
...
lstein/fea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8b76d112be | ||
|
|
6e40142a59 | ||
|
|
00ccd73d53 | ||
|
|
964adb817c | ||
|
|
7d8b011f89 | ||
|
|
d487102904 | ||
|
|
4c081d58e0 | ||
|
|
18b5aafade | ||
|
|
6946a3871f | ||
|
|
fc23b16a73 | ||
|
|
31f63028fd | ||
|
|
2dd42d0917 | ||
|
|
2bba7f38b9 | ||
|
|
a48abfacf4 | ||
|
|
d5aee87684 | ||
|
|
36b14343c7 | ||
|
|
59deef97c5 | ||
|
|
d852ca7a8d | ||
|
|
d24877561d | ||
|
|
8144a263de | ||
|
|
ab086a7069 | ||
|
|
048306b417 | ||
|
|
6eaed9a9cb | ||
|
|
ab9ebef345 | ||
|
|
984dd93798 | ||
|
|
d12fb7db68 | ||
|
|
5d411e446a | ||
|
|
6f128c86b4 | ||
|
|
aca9e44a3a | ||
|
|
e39f035264 | ||
|
|
b612c73954 | ||
|
|
36495b730d | ||
|
|
6ad1948a44 |
@@ -26,7 +26,7 @@ import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
|
||||
custom_nodes_path = Path(get_config().custom_nodes_path)
|
||||
custom_nodes_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -33,7 +33,7 @@ from invokeai.app.invocations.fields import (
|
||||
FieldKind,
|
||||
Input,
|
||||
)
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from invokeai.app.services.config.config_common import PagingArgumentParser
|
||||
|
||||
from .config_default import InvokeAIAppConfig, get_config
|
||||
from .config_default import InvokeAIAppConfig
|
||||
from .config_migrate import get_config
|
||||
|
||||
__all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"]
|
||||
|
||||
@@ -12,6 +12,10 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import pydoc
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, TypeAlias
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
@@ -23,3 +27,21 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
||||
def print_help(self, file=None) -> None:
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
|
||||
AppConfigDict: TypeAlias = dict[str, Any]
|
||||
|
||||
ConfigMigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigMigration:
|
||||
"""Defines an individual config migration."""
|
||||
|
||||
from_version: Version
|
||||
to_version: Version
|
||||
function: ConfigMigrationFunction
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
||||
return hash((self.from_version, self.to_version))
|
||||
|
||||
@@ -3,11 +3,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import locale
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
@@ -16,9 +13,7 @@ import yaml
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
|
||||
|
||||
import invokeai.configs as model_configs
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
@@ -346,169 +341,3 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (init_settings,)
|
||||
|
||||
|
||||
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
"""Migrate a v3 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v3 config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the migrated settings.
|
||||
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
||||
for k, v in category_dict.items():
|
||||
# `outdir` was renamed to `outputs_dir` in v4
|
||||
if k == "outdir":
|
||||
parsed_config_dict["outputs_dir"] = v
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
parsed_config_dict["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
parsed_config_dict["vram"] = v
|
||||
# autocast was removed in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
if k == "conf_path":
|
||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||
# If if the incoming config has the default value, skip
|
||||
continue
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||
else:
|
||||
# Else we do not attempt to migrate this setting
|
||||
parsed_config_dict["legacy_conf_dir"] = v
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
parsed_config_dict[k] = v
|
||||
# When migrating the config file, we should not include currently-set environment variables.
|
||||
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
"""Migrate v4.0.0 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v4.0.0 config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the migrated settings.
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for k, v in config_dict.items():
|
||||
# autocast was removed from precision in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
else:
|
||||
parsed_config_dict[k] = v
|
||||
if k == "schema_version":
|
||||
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
|
||||
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
|
||||
return config
|
||||
|
||||
|
||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
"""Load and migrate a config file to the latest version.
|
||||
|
||||
Args:
|
||||
config_path: Path to the config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
|
||||
"""
|
||||
assert config_path.suffix == ".yaml"
|
||||
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||
loaded_config_dict = yaml.safe_load(file)
|
||||
|
||||
assert isinstance(loaded_config_dict, dict)
|
||||
|
||||
if "InvokeAI" in loaded_config_dict:
|
||||
# This is a v3 config file, attempt to migrate it
|
||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||
try:
|
||||
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
|
||||
migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
||||
except Exception as e:
|
||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
|
||||
migrated_config.write_file(config_path)
|
||||
return migrated_config
|
||||
|
||||
if loaded_config_dict["schema_version"] == "4.0.0":
|
||||
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
|
||||
loaded_config_dict.write_file(config_path)
|
||||
|
||||
# Attempt to load as a v4 config file
|
||||
try:
|
||||
# Meta is not included in the model fields, so we need to validate it separately
|
||||
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
||||
assert (
|
||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||
return config
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_config() -> InvokeAIAppConfig:
|
||||
"""Get the global singleton app config.
|
||||
|
||||
When first called, this function:
|
||||
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
|
||||
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
|
||||
- Sets the root dir, if provided via CLI args.
|
||||
- Logs in to HF if there is no valid token already.
|
||||
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
|
||||
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
|
||||
|
||||
On subsequent calls, the object is returned from the cache.
|
||||
"""
|
||||
# This object includes environment variables, as parsed by pydantic-settings
|
||||
config = InvokeAIAppConfig()
|
||||
|
||||
args = InvokeAIArgs.args
|
||||
|
||||
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
|
||||
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
|
||||
if not InvokeAIArgs.did_parse:
|
||||
return config
|
||||
|
||||
# Set CLI args
|
||||
if root := getattr(args, "root", None):
|
||||
config._root = Path(root)
|
||||
if config_file := getattr(args, "config_file", None):
|
||||
config._config_file = Path(config_file)
|
||||
|
||||
# Create the example config file, with some extra example values provided
|
||||
example_config = DefaultInvokeAIAppConfig()
|
||||
example_config.remote_api_tokens = [
|
||||
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
|
||||
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
|
||||
]
|
||||
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
|
||||
|
||||
# Copy all legacy configs - We know `__path__[0]` is correct here
|
||||
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
|
||||
|
||||
if config.config_file_path.exists():
|
||||
config_from_file = load_and_migrate_config(config.config_file_path)
|
||||
# Clobbering here will overwrite any settings that were set via environment variables
|
||||
config.update_config(config_from_file, clobber=False)
|
||||
else:
|
||||
# We should never write env vars to the config file
|
||||
default_config = DefaultInvokeAIAppConfig()
|
||||
default_config.write_file(config.config_file_path, as_example=False)
|
||||
|
||||
return config
|
||||
|
||||
177
invokeai/app/services/config/config_migrate.py
Normal file
177
invokeai/app/services/config/config_migrate.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
Utility class for migrating among versions of the InvokeAI app config schema.
|
||||
"""
|
||||
|
||||
import locale
|
||||
import shutil
|
||||
from copy import deepcopy
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Iterable
|
||||
|
||||
import yaml
|
||||
from packaging.version import Version
|
||||
|
||||
import invokeai.configs as model_configs
|
||||
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
|
||||
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
|
||||
|
||||
|
||||
class ConfigMigrator:
|
||||
"""This class allows migrators to register their input and output versions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._migrations: set[ConfigMigration] = set()
|
||||
|
||||
def register(self, migration: ConfigMigration) -> None:
|
||||
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
|
||||
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
|
||||
if migration_from_already_registered or migration_to_already_registered:
|
||||
raise ValueError(
|
||||
f"A migration from {migration.from_version} or to {migration.to_version} has already been registered."
|
||||
)
|
||||
self._migrations.add(migration)
|
||||
|
||||
@staticmethod
|
||||
def _check_for_discontinuities(migrations: Iterable[ConfigMigration]) -> None:
|
||||
current_version = Version("3.0.0")
|
||||
sorted_migrations = sorted(migrations, key=lambda x: x.from_version)
|
||||
for m in sorted_migrations:
|
||||
if current_version != m.from_version:
|
||||
raise ValueError(
|
||||
f"Migration functions are not continuous. Expected from_version={current_version} but got from_version={m.from_version}, for migration function {m.function.__name__}"
|
||||
)
|
||||
current_version = m.to_version
|
||||
|
||||
def run_migrations(self, original_config: AppConfigDict) -> AppConfigDict:
|
||||
"""
|
||||
Use the registered migrations to bring config up to latest version.
|
||||
|
||||
Args:
|
||||
original_config: The original configuration.
|
||||
|
||||
Returns:
|
||||
The new configuration, lifted up to the latest version.
|
||||
"""
|
||||
|
||||
# Sort migrations by version number and raise a ValueError if any version range overlaps are detected.
|
||||
sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
|
||||
self._check_for_discontinuities(sorted_migrations)
|
||||
|
||||
# Do not mutate the incoming dict - we don't know who else may be using it
|
||||
migrated_config = deepcopy(original_config)
|
||||
|
||||
# v3.0.0 configs did not have "schema_version", but did have "InvokeAI"
|
||||
if "InvokeAI" in migrated_config:
|
||||
version = Version("3.0.0")
|
||||
else:
|
||||
version = Version(migrated_config["schema_version"])
|
||||
|
||||
for migration in sorted_migrations:
|
||||
if version == migration.from_version:
|
||||
migrated_config = migration.function(migrated_config)
|
||||
version = migration.to_version
|
||||
|
||||
# We must end on the latest version
|
||||
assert migrated_config["schema_version"] == str(sorted_migrations[-1].to_version)
|
||||
return migrated_config
|
||||
|
||||
|
||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
"""Load and migrate a config file to the latest version.
|
||||
|
||||
Args:
|
||||
config_path: Path to the config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
|
||||
"""
|
||||
assert config_path.suffix == ".yaml"
|
||||
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||
loaded_config_dict: AppConfigDict = yaml.safe_load(file)
|
||||
|
||||
assert isinstance(loaded_config_dict, dict)
|
||||
|
||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||
try:
|
||||
migrator = ConfigMigrator()
|
||||
migrator.register(config_migration_1)
|
||||
migrator.register(config_migration_2)
|
||||
migrated_config_dict = migrator.run_migrations(loaded_config_dict)
|
||||
except Exception as e:
|
||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
|
||||
|
||||
# Attempt to load as a v4 config file
|
||||
try:
|
||||
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
|
||||
assert (
|
||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
|
||||
return config
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
|
||||
|
||||
# TODO(psyche): This must must be in this file to avoid circular dependencies
|
||||
@lru_cache(maxsize=1)
|
||||
def get_config() -> InvokeAIAppConfig:
|
||||
"""Get the global singleton app config.
|
||||
|
||||
When first called, this function:
|
||||
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
|
||||
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
|
||||
- Sets the root dir, if provided via CLI args.
|
||||
- Logs in to HF if there is no valid token already.
|
||||
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
|
||||
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
|
||||
|
||||
On subsequent calls, the object is returned from the cache.
|
||||
"""
|
||||
# This object includes environment variables, as parsed by pydantic-settings
|
||||
config = InvokeAIAppConfig()
|
||||
|
||||
args = InvokeAIArgs.args
|
||||
|
||||
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
|
||||
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
|
||||
if not InvokeAIArgs.did_parse:
|
||||
tmpdir = TemporaryDirectory()
|
||||
config._root = Path(tmpdir.name)
|
||||
return config
|
||||
|
||||
# Set CLI args
|
||||
if root := getattr(args, "root", None):
|
||||
config._root = Path(root)
|
||||
if config_file := getattr(args, "config_file", None):
|
||||
config._config_file = Path(config_file)
|
||||
|
||||
# Create the example config file, with some extra example values provided
|
||||
example_config = DefaultInvokeAIAppConfig()
|
||||
example_config.remote_api_tokens = [
|
||||
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
|
||||
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
|
||||
]
|
||||
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
|
||||
|
||||
# Copy all legacy configs - We know `__path__[0]` is correct here
|
||||
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
|
||||
|
||||
if config.config_file_path.exists():
|
||||
config_from_file = load_and_migrate_config(config.config_file_path)
|
||||
config_from_file.write_file(config.config_file_path)
|
||||
# Clobbering here will overwrite any settings that were set via environment variables
|
||||
config.update_config(config_from_file, clobber=False)
|
||||
else:
|
||||
# We should never write env vars to the config file
|
||||
default_config = DefaultInvokeAIAppConfig()
|
||||
default_config.write_file(config.config_file_path, as_example=False)
|
||||
|
||||
return config
|
||||
102
invokeai/app/services/config/migrations.py
Normal file
102
invokeai/app/services/config/migrations.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
Schema migrations to perform on an InvokeAIAppConfig object.
|
||||
|
||||
The Migrations class defined in this module defines a series of
|
||||
schema version migration steps for the InvokeAIConfig object.
|
||||
|
||||
To define a new migration, add a migration function to
|
||||
Migrations.load_migrations() following the existing examples.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
|
||||
|
||||
from .config_default import InvokeAIAppConfig
|
||||
|
||||
|
||||
def migrate_v300_to_v400(original_config: AppConfigDict) -> AppConfigDict:
|
||||
"""Migrate a v3.0.0 config dict to v4.0.0.
|
||||
|
||||
Changes in this migration:
|
||||
- `outdir` was renamed to `outputs_dir`
|
||||
- `max_cache_size` was renamed to `ram`
|
||||
- `max_vram_cache_size` was renamed to `vram`
|
||||
- `conf_path`, which pointed to the old `models.yaml`, was removed - but if need to stash it to migrate the entries
|
||||
to the database
|
||||
- `legacy_conf_dir` was changed from a path relative to the app root, to a path relative to $INVOKEAI_ROOT/configs
|
||||
|
||||
Args:
|
||||
config_dict: The v3.0.0 config dict to migrate.
|
||||
|
||||
Returns:
|
||||
The migrated v4.0.0 config dict.
|
||||
"""
|
||||
migrated_config: AppConfigDict = {}
|
||||
for _category_name, category_dict in original_config["InvokeAI"].items():
|
||||
for k, v in category_dict.items():
|
||||
# `outdir` was renamed to `outputs_dir` in v4
|
||||
if k == "outdir":
|
||||
migrated_config["outputs_dir"] = v
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
migrated_config["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
migrated_config["vram"] = v
|
||||
if k == "conf_path":
|
||||
migrated_config["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||
# If if the incoming config has the default value, skip
|
||||
continue
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
migrated_config["legacy_conf_dir"] = str(Path(v).parent)
|
||||
else:
|
||||
# Else we do not attempt to migrate this setting
|
||||
migrated_config["legacy_conf_dir"] = v
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
migrated_config[k] = v
|
||||
migrated_config["schema_version"] = "4.0.0"
|
||||
return migrated_config
|
||||
|
||||
|
||||
config_migration_1 = ConfigMigration(
|
||||
from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400
|
||||
)
|
||||
|
||||
|
||||
def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict:
|
||||
"""Migrate a v4.0.0 config dict to v4.0.1.
|
||||
|
||||
Changes in this migration:
|
||||
- `precision: "autocast"` was removed, fall back to "auto"
|
||||
|
||||
Args:
|
||||
config_dict: The v4.0.0 config dict to migrate.
|
||||
|
||||
Returns:
|
||||
The migrated v4.0.1 config dict.
|
||||
"""
|
||||
migrated_config: AppConfigDict = {}
|
||||
for k, v in original_config.items():
|
||||
# autocast was removed from precision in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
migrated_config["precision"] = "auto"
|
||||
# skip unknown fields
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
migrated_config[k] = v
|
||||
migrated_config["schema_version"] = "4.0.1"
|
||||
return migrated_config
|
||||
|
||||
|
||||
config_migration_2 = ConfigMigration(
|
||||
from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401
|
||||
)
|
||||
@@ -9,7 +9,7 @@ from torch import Tensor
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from einops import repeat
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
|
||||
|
||||
class PatchMatch:
|
||||
|
||||
@@ -10,7 +10,7 @@ from imwatermark import WatermarkEncoder
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
|
||||
config = get_config()
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from PIL import Image, ImageFilter
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from pydantic import Field
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, Union
|
||||
import torch
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
IPAdapterData,
|
||||
Range,
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Dict, Literal, Optional, Union
|
||||
import torch
|
||||
from deprecated import deprecated
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
|
||||
# legacy APIs
|
||||
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
|
||||
|
||||
@@ -180,8 +180,7 @@ import urllib.parse
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, get_config
|
||||
|
||||
try:
|
||||
import syslog
|
||||
|
||||
@@ -1,22 +1,35 @@
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
import yaml
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
|
||||
from invokeai.app.services.config.config_default import (
|
||||
DefaultInvokeAIAppConfig,
|
||||
InvokeAIAppConfig,
|
||||
get_config,
|
||||
load_and_migrate_config,
|
||||
)
|
||||
from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config
|
||||
from invokeai.app.services.config.migrations import migrate_v300_to_v400, migrate_v400_to_v401
|
||||
from invokeai.app.services.shared.graph import Graph
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
invalid_v4_0_1_config = """
|
||||
schema_version: 4.0.1
|
||||
|
||||
host: "192.168.1.1"
|
||||
port: "ice cream"
|
||||
"""
|
||||
|
||||
v4_config = """
|
||||
schema_version: 4.0.0
|
||||
|
||||
precision: autocast
|
||||
host: "192.168.1.1"
|
||||
port: 8080
|
||||
"""
|
||||
@@ -59,20 +72,104 @@ i like turtles
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
||||
"""This may be overkill since the current tests don't need the root dir to exist"""
|
||||
monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
|
||||
def test_config_migrator_registers_migrations() -> None:
|
||||
"""Test that the config migrator registers migrations."""
|
||||
migrator = ConfigMigrator()
|
||||
|
||||
def migration_func(config: AppConfigDict) -> AppConfigDict:
|
||||
return config
|
||||
|
||||
migration_1 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migration_func)
|
||||
migration_2 = ConfigMigration(from_version=Version("4.0.0"), to_version=Version("5.0.0"), function=migration_func)
|
||||
|
||||
migrator.register(migration_1)
|
||||
assert migrator._migrations == {migration_1}
|
||||
migrator.register(migration_2)
|
||||
assert migrator._migrations == {migration_1, migration_2}
|
||||
|
||||
|
||||
def test_path_resolution_root_not_set(patch_rootdir: None):
|
||||
def test_config_migrator_rejects_duplicate_migrations() -> None:
|
||||
"""Test that the config migrator rejects duplicate migrations."""
|
||||
migrator = ConfigMigrator()
|
||||
|
||||
def migration_func(config: AppConfigDict) -> AppConfigDict:
|
||||
return config
|
||||
|
||||
migration_1 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migration_func)
|
||||
migrator.register(migration_1)
|
||||
|
||||
# Re-register the same migration
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=f"A migration from {migration_1.from_version} or to {migration_1.to_version} has already been registered.",
|
||||
):
|
||||
migrator.register(migration_1)
|
||||
|
||||
# Register a migration with the same from_version
|
||||
migration_2 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("5.0.0"), function=migration_func)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=f"A migration from {migration_2.from_version} or to {migration_2.to_version} has already been registered.",
|
||||
):
|
||||
migrator.register(migration_2)
|
||||
|
||||
# Register a migration with the same to_version
|
||||
migration_3 = ConfigMigration(from_version=Version("3.0.1"), to_version=Version("4.0.0"), function=migration_func)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=f"A migration from {migration_3.from_version} or to {migration_3.to_version} has already been registered.",
|
||||
):
|
||||
migrator.register(migration_3)
|
||||
|
||||
|
||||
def test_config_migrator_contiguous_migrations() -> None:
|
||||
"""Test that the config migrator requires contiguous migrations."""
|
||||
migrator = ConfigMigrator()
|
||||
|
||||
def migration_1_func(config: AppConfigDict) -> AppConfigDict:
|
||||
return {"schema_version": "4.0.0"}
|
||||
|
||||
def migration_3_func(config: AppConfigDict) -> AppConfigDict:
|
||||
return {"schema_version": "6.0.0"}
|
||||
|
||||
migration_1 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migration_1_func)
|
||||
migration_3 = ConfigMigration(from_version=Version("5.0.0"), to_version=Version("6.0.0"), function=migration_3_func)
|
||||
|
||||
migrator.register(migration_1)
|
||||
migrator.register(migration_3)
|
||||
with pytest.raises(ValueError, match="Migration functions are not continuous"):
|
||||
migrator._check_for_discontinuities(migrator._migrations)
|
||||
|
||||
|
||||
def test_config_migrator_runs_migrations() -> None:
|
||||
"""Test that the config migrator runs migrations."""
|
||||
migrator = ConfigMigrator()
|
||||
|
||||
def migration_1_func(config: AppConfigDict) -> AppConfigDict:
|
||||
return {"schema_version": "4.0.0"}
|
||||
|
||||
def migration_2_func(config: AppConfigDict) -> AppConfigDict:
|
||||
return {"schema_version": "5.0.0"}
|
||||
|
||||
migration_1 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migration_1_func)
|
||||
migration_2 = ConfigMigration(from_version=Version("4.0.0"), to_version=Version("5.0.0"), function=migration_2_func)
|
||||
|
||||
migrator.register(migration_1)
|
||||
migrator.register(migration_2)
|
||||
|
||||
original_config = {"schema_version": "3.0.0"}
|
||||
migrated_config = migrator.run_migrations(original_config)
|
||||
assert migrated_config == {"schema_version": "5.0.0"}
|
||||
|
||||
|
||||
def test_path_resolution_root_not_set():
|
||||
"""Test path resolutions when the root is not explicitly set."""
|
||||
config = InvokeAIAppConfig()
|
||||
expected_root = InvokeAIAppConfig.find_root()
|
||||
assert config.root_path == expected_root
|
||||
|
||||
|
||||
def test_read_config_from_file(tmp_path: Path, patch_rootdir: None):
|
||||
def test_read_config_from_file(tmp_path: Path):
|
||||
"""Test reading configuration from a file."""
|
||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||
temp_config_file.write_text(v4_config)
|
||||
@@ -82,12 +179,10 @@ def test_read_config_from_file(tmp_path: Path, patch_rootdir: None):
|
||||
assert config.port == 8080
|
||||
|
||||
|
||||
def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
|
||||
def test_migration_1_migrates_settings(tmp_path: Path):
|
||||
"""Test reading configuration from a file."""
|
||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||
temp_config_file.write_text(v3_config)
|
||||
|
||||
config = load_and_migrate_config(temp_config_file)
|
||||
migrated_config_dict = migrate_v300_to_v400(yaml.safe_load(v3_config))
|
||||
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
|
||||
assert config.outputs_dir == Path("/some/outputs/dir")
|
||||
assert config.host == "192.168.1.1"
|
||||
assert config.port == 8080
|
||||
@@ -111,20 +206,18 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
|
||||
("full/custom/path", Path("full/custom/path"), True),
|
||||
],
|
||||
)
|
||||
def test_migrate_v3_legacy_conf_dir_defaults(
|
||||
tmp_path: Path, patch_rootdir: None, legacy_conf_dir: str, expected_value: Path, expected_is_set: bool
|
||||
def test_migration_1_handles_legacy_conf_dir_defaults(
|
||||
legacy_conf_dir: str, expected_value: Path, expected_is_set: bool
|
||||
):
|
||||
"""Test reading configuration from a file."""
|
||||
config_content = f"InvokeAI:\n Paths:\n legacy_conf_dir: {legacy_conf_dir}"
|
||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||
temp_config_file.write_text(config_content)
|
||||
|
||||
config = load_and_migrate_config(temp_config_file)
|
||||
migrated_config_dict = migrate_v300_to_v400(yaml.safe_load(config_content))
|
||||
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
|
||||
assert config.legacy_conf_dir == expected_value
|
||||
assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set
|
||||
|
||||
|
||||
def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
|
||||
def test_load_and_migrate_backs_up_file(tmp_path: Path):
|
||||
"""Test the backup of the config file."""
|
||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||
temp_config_file.write_text(v3_config)
|
||||
@@ -134,7 +227,15 @@ def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
|
||||
assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config
|
||||
|
||||
|
||||
def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
|
||||
def test_migration_2_migrates_settings():
|
||||
"""Test migration from 4.0.0 to 4.0.1"""
|
||||
migrated_config_dict = migrate_v400_to_v401(yaml.safe_load(v4_config))
|
||||
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
|
||||
assert Version(config.schema_version) == Version("4.0.1")
|
||||
assert config.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration
|
||||
|
||||
|
||||
def test_load_and_migrate_failed_migrate_backup(tmp_path: Path):
|
||||
"""Test the failed migration of the config file."""
|
||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||
temp_config_file.write_text(v3_config_with_bad_values)
|
||||
@@ -147,7 +248,7 @@ def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
|
||||
assert temp_config_file.read_text() == v3_config_with_bad_values
|
||||
|
||||
|
||||
def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None):
|
||||
def test_load_and_migrate_bails_on_invalid_config(tmp_path: Path):
|
||||
"""Test reading configuration from a file."""
|
||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||
temp_config_file.write_text(invalid_config)
|
||||
@@ -156,16 +257,18 @@ def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None):
|
||||
load_and_migrate_config(temp_config_file)
|
||||
|
||||
|
||||
def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None):
|
||||
@pytest.mark.parametrize("config_content", [invalid_v5_config, invalid_v4_0_1_config])
|
||||
def test_bails_on_config_with_unsupported_version(tmp_path: Path, config_content: str):
|
||||
"""Test reading configuration from a file."""
|
||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||
temp_config_file.write_text(invalid_v5_config)
|
||||
temp_config_file.write_text(config_content)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid schema version"):
|
||||
# with pytest.raises(RuntimeError, match="Invalid schema version"):
|
||||
with pytest.raises(RuntimeError):
|
||||
load_and_migrate_config(temp_config_file)
|
||||
|
||||
|
||||
def test_write_config_to_file(patch_rootdir: None):
|
||||
def test_write_config_to_file():
|
||||
"""Test writing configuration to a file, checking for correct output."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
temp_config_path = Path(tmpdir) / "invokeai.yaml"
|
||||
@@ -180,7 +283,7 @@ def test_write_config_to_file(patch_rootdir: None):
|
||||
assert "port: 8080" in content
|
||||
|
||||
|
||||
def test_update_config_with_dict(patch_rootdir: None):
|
||||
def test_update_config_with_dict():
|
||||
"""Test updating the config with a dictionary."""
|
||||
config = InvokeAIAppConfig()
|
||||
update_dict = {"host": "10.10.10.10", "port": 6060}
|
||||
@@ -189,7 +292,7 @@ def test_update_config_with_dict(patch_rootdir: None):
|
||||
assert config.port == 6060
|
||||
|
||||
|
||||
def test_update_config_with_object(patch_rootdir: None):
|
||||
def test_update_config_with_object():
|
||||
"""Test updating the config with another config object."""
|
||||
config = InvokeAIAppConfig()
|
||||
new_config = InvokeAIAppConfig(host="10.10.10.10", port=6060)
|
||||
@@ -198,7 +301,7 @@ def test_update_config_with_object(patch_rootdir: None):
|
||||
assert config.port == 6060
|
||||
|
||||
|
||||
def test_set_and_resolve_paths(patch_rootdir: None):
|
||||
def test_set_and_resolve_paths():
|
||||
"""Test setting root and resolving paths based on it."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
config = InvokeAIAppConfig()
|
||||
@@ -207,7 +310,7 @@ def test_set_and_resolve_paths(patch_rootdir: None):
|
||||
assert config.db_path == Path(tmpdir).resolve() / "databases" / "invokeai.db"
|
||||
|
||||
|
||||
def test_singleton_behavior(patch_rootdir: None):
|
||||
def test_singleton_behavior():
|
||||
"""Test that get_config always returns the same instance."""
|
||||
get_config.cache_clear()
|
||||
config1 = get_config()
|
||||
@@ -216,13 +319,13 @@ def test_singleton_behavior(patch_rootdir: None):
|
||||
get_config.cache_clear()
|
||||
|
||||
|
||||
def test_default_config(patch_rootdir: None):
|
||||
def test_default_config():
|
||||
"""Test that the default config is as expected."""
|
||||
config = DefaultInvokeAIAppConfig()
|
||||
assert config.host == "127.0.0.1"
|
||||
|
||||
|
||||
def test_env_vars(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
||||
def test_env_vars(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
||||
"""Test that environment variables are merged into the config"""
|
||||
monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
|
||||
monkeypatch.setenv("INVOKEAI_HOST", "1.2.3.4")
|
||||
@@ -233,7 +336,7 @@ def test_env_vars(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path
|
||||
assert config.root_path == tmp_path
|
||||
|
||||
|
||||
def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
||||
def test_get_config_writing(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
||||
"""Test that get_config writes the appropriate files to disk"""
|
||||
# Trick the config into thinking it has already parsed args - this triggers the writing of the config file
|
||||
InvokeAIArgs.did_parse = True
|
||||
@@ -265,58 +368,39 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
|
||||
InvokeAIArgs.did_parse = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def clear_config() -> Generator[None, None, None]:
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
get_config.cache_clear()
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="""
|
||||
This test fails when run as part of the full test suite.
|
||||
|
||||
This test needs to deny nodes from being included in the InvocationsUnion by providing
|
||||
an app configuration as a test fixture. Pytest executes all test files before running
|
||||
tests, so the app configuration is already initialized by the time this test runs, and
|
||||
the InvocationUnion is already created and the denied nodes are not omitted from it.
|
||||
|
||||
This test passes when `test_config.py` is tested in isolation.
|
||||
|
||||
Perhaps a solution would be to call `get_app_config().parse_args()` in
|
||||
other test files?
|
||||
"""
|
||||
Currently this test is failing due to an issue described in issue #5983.
|
||||
"""
|
||||
)
|
||||
def test_deny_nodes(patch_rootdir):
|
||||
# Allow integer, string and float, but explicitly deny float
|
||||
allow_deny_nodes_conf = OmegaConf.create(
|
||||
"""
|
||||
InvokeAI:
|
||||
Nodes:
|
||||
allow_nodes:
|
||||
- integer
|
||||
- string
|
||||
- float
|
||||
deny_nodes:
|
||||
- float
|
||||
"""
|
||||
)
|
||||
# must parse config before importing Graph, so its nodes union uses the config
|
||||
get_config.cache_clear()
|
||||
conf = get_config()
|
||||
get_config.cache_clear()
|
||||
conf.merge_from_file(conf=allow_deny_nodes_conf, argv=[])
|
||||
from invokeai.app.services.shared.graph import Graph
|
||||
def test_deny_nodes():
|
||||
with clear_config():
|
||||
config = get_config()
|
||||
config.allow_nodes = ["integer", "string", "float"]
|
||||
config.deny_nodes = ["float"]
|
||||
|
||||
# confirm graph validation fails when using denied node
|
||||
Graph(nodes={"1": {"id": "1", "type": "integer"}})
|
||||
Graph(nodes={"1": {"id": "1", "type": "string"}})
|
||||
# confirm graph validation fails when using denied node
|
||||
Graph(nodes={"1": {"id": "1", "type": "integer"}})
|
||||
Graph(nodes={"1": {"id": "1", "type": "string"}})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||
with pytest.raises(ValidationError):
|
||||
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
# confirm invocations union will not have denied nodes
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
|
||||
# confirm invocations union will not have denied nodes
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
|
||||
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
|
||||
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
|
||||
|
||||
has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1
|
||||
has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1
|
||||
has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1
|
||||
|
||||
assert has_integer
|
||||
assert has_string
|
||||
assert not has_float
|
||||
assert has_integer
|
||||
assert has_string
|
||||
assert not has_float
|
||||
|
||||
Reference in New Issue
Block a user