mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
multiple small fixes suggested in reviews from psychedelicious and ryan
This commit is contained in:
@@ -16,7 +16,7 @@ from invokeai.app.services.model_records import (
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
DiffusersConfig,
|
||||
MainDiffusersConfig,
|
||||
ModelType,
|
||||
TextualInversionConfig,
|
||||
VaeDiffusersConfig,
|
||||
@@ -83,6 +83,16 @@ def test_update(store: ModelRecordServiceBase):
|
||||
new_config = store.get_model("key1")
|
||||
assert new_config.name == "new name"
|
||||
|
||||
def test_rename(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", config)
|
||||
config = store.get_model("key1")
|
||||
assert config.name == "old name"
|
||||
|
||||
store.rename_model("key1", "new name")
|
||||
new_config = store.get_model("key1")
|
||||
assert new_config.name == "new name"
|
||||
|
||||
|
||||
def test_unknown_key(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
@@ -108,14 +118,14 @@ def test_exists(store: ModelRecordServiceBase):
|
||||
|
||||
|
||||
def test_filter(store: ModelRecordServiceBase):
|
||||
config1 = DiffusersConfig(
|
||||
config1 = MainDiffusersConfig(
|
||||
path="/tmp/config1",
|
||||
name="config1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config2 = DiffusersConfig(
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
name="config2",
|
||||
base=BaseModelType("sd-1"),
|
||||
@@ -131,17 +141,17 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
)
|
||||
for c in config1, config2, config3:
|
||||
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
||||
matches = store.search_by_name(model_type=ModelType("main"))
|
||||
matches = store.search_by_attr(model_type=ModelType("main"))
|
||||
assert len(matches) == 2
|
||||
assert matches[0].name in {"config1", "config2"}
|
||||
|
||||
matches = store.search_by_name(model_type=ModelType("vae"))
|
||||
matches = store.search_by_attr(model_type=ModelType("vae"))
|
||||
assert len(matches) == 1
|
||||
assert matches[0].name == "config3"
|
||||
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
||||
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
|
||||
|
||||
matches = store.search_by_name(model_type=BaseModelType("sd-2"))
|
||||
matches = store.search_by_attr(model_type=BaseModelType("sd-2"))
|
||||
|
||||
matches = store.search_by_hash("CONFIG1HASH")
|
||||
assert len(matches) == 1
|
||||
@@ -152,28 +162,28 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
|
||||
|
||||
def test_filter_2(store: ModelRecordServiceBase):
|
||||
config1 = DiffusersConfig(
|
||||
config1 = MainDiffusersConfig(
|
||||
path="/tmp/config1",
|
||||
name="config1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config2 = DiffusersConfig(
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
name="config2",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG2HASH",
|
||||
)
|
||||
config3 = DiffusersConfig(
|
||||
config3 = MainDiffusersConfig(
|
||||
path="/tmp/config3",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG3HASH",
|
||||
)
|
||||
config4 = DiffusersConfig(
|
||||
config4 = MainDiffusersConfig(
|
||||
path="/tmp/config4",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-2"),
|
||||
@@ -190,19 +200,19 @@ def test_filter_2(store: ModelRecordServiceBase):
|
||||
for c in config1, config2, config3, config4, config5:
|
||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
||||
|
||||
matches = store.search_by_name(
|
||||
matches = store.search_by_attr(
|
||||
model_type=ModelType("main"),
|
||||
model_name="dup_name1",
|
||||
)
|
||||
assert len(matches) == 2
|
||||
|
||||
matches = store.search_by_name(
|
||||
matches = store.search_by_attr(
|
||||
base_model=BaseModelType("sd-1"),
|
||||
model_type=ModelType("main"),
|
||||
)
|
||||
assert len(matches) == 2
|
||||
|
||||
matches = store.search_by_name(
|
||||
matches = store.search_by_attr(
|
||||
base_model=BaseModelType("sd-1"),
|
||||
model_type=ModelType("vae"),
|
||||
model_name="dup_name1",
|
||||
|
||||
Reference in New Issue
Block a user