multiple small fixes suggested in reviews from psychedelicious and ryan

This commit is contained in:
Lincoln Stein
2023-11-10 18:25:37 -05:00
parent fdaa661245
commit 0544917161
4 changed files with 48 additions and 37 deletions

View File

@@ -16,7 +16,7 @@ from invokeai.app.services.model_records import (
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.model_manager.config import (
BaseModelType,
DiffusersConfig,
MainDiffusersConfig,
ModelType,
TextualInversionConfig,
VaeDiffusersConfig,
@@ -83,6 +83,16 @@ def test_update(store: ModelRecordServiceBase):
new_config = store.get_model("key1")
assert new_config.name == "new name"
def test_rename(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
config = store.get_model("key1")
assert config.name == "old name"
store.rename_model("key1", "new name")
new_config = store.get_model("key1")
assert new_config.name == "new name"
def test_unknown_key(store: ModelRecordServiceBase):
config = example_config()
@@ -108,14 +118,14 @@ def test_exists(store: ModelRecordServiceBase):
def test_filter(store: ModelRecordServiceBase):
config1 = DiffusersConfig(
config1 = MainDiffusersConfig(
path="/tmp/config1",
name="config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG1HASH",
)
config2 = DiffusersConfig(
config2 = MainDiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType("sd-1"),
@@ -131,17 +141,17 @@ def test_filter(store: ModelRecordServiceBase):
)
for c in config1, config2, config3:
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
matches = store.search_by_name(model_type=ModelType("main"))
matches = store.search_by_attr(model_type=ModelType("main"))
assert len(matches) == 2
assert matches[0].name in {"config1", "config2"}
matches = store.search_by_name(model_type=ModelType("vae"))
matches = store.search_by_attr(model_type=ModelType("vae"))
assert len(matches) == 1
assert matches[0].name == "config3"
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
matches = store.search_by_name(model_type=BaseModelType("sd-2"))
matches = store.search_by_attr(model_type=BaseModelType("sd-2"))
matches = store.search_by_hash("CONFIG1HASH")
assert len(matches) == 1
@@ -152,28 +162,28 @@ def test_filter(store: ModelRecordServiceBase):
def test_filter_2(store: ModelRecordServiceBase):
config1 = DiffusersConfig(
config1 = MainDiffusersConfig(
path="/tmp/config1",
name="config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG1HASH",
)
config2 = DiffusersConfig(
config2 = MainDiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG2HASH",
)
config3 = DiffusersConfig(
config3 = MainDiffusersConfig(
path="/tmp/config3",
name="dup_name1",
base=BaseModelType("sd-2"),
type=ModelType("main"),
original_hash="CONFIG3HASH",
)
config4 = DiffusersConfig(
config4 = MainDiffusersConfig(
path="/tmp/config4",
name="dup_name1",
base=BaseModelType("sd-2"),
@@ -190,19 +200,19 @@ def test_filter_2(store: ModelRecordServiceBase):
for c in config1, config2, config3, config4, config5:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
matches = store.search_by_name(
matches = store.search_by_attr(
model_type=ModelType("main"),
model_name="dup_name1",
)
assert len(matches) == 2
matches = store.search_by_name(
matches = store.search_by_attr(
base_model=BaseModelType("sd-1"),
model_type=ModelType("main"),
)
assert len(matches) == 2
matches = store.search_by_name(
matches = store.search_by_attr(
base_model=BaseModelType("sd-1"),
model_type=ModelType("vae"),
model_name="dup_name1",