mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
add class docstring and blackify
This commit is contained in:
@@ -79,4 +79,3 @@ class ModelConfigStore(ABC):
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,5 +1,38 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Implementation of ModelConfigStore using a YAML file."""
|
||||
"""
|
||||
Implementation of ModelConfigStore using a YAML file.
|
||||
|
||||
Typical usage:
|
||||
|
||||
from invokeai.backend.model_management2.storage.yaml import ModelConfigStoreYAML
|
||||
store = ModelConfigStoreYAML("./configs/models.yaml")
|
||||
config = dict(
|
||||
path='/tmp/pokemon.bin',
|
||||
name='old name',
|
||||
base_model='sd-1',
|
||||
model_type='embedding',
|
||||
model_format='embedding_file',
|
||||
author='Anonymous',
|
||||
)
|
||||
|
||||
# adding
|
||||
store.add_model('key1', config)
|
||||
|
||||
# updating
|
||||
config.name='new name'
|
||||
store.update_model('key1', config)
|
||||
|
||||
# checking for existence
|
||||
if store.exists('key1'):
|
||||
print("yes")
|
||||
|
||||
# fetching config
|
||||
new_config = store.get_model('key1')
|
||||
print(new_config.name, new_config.base_model)
|
||||
|
||||
# deleting
|
||||
store.del_model('key1')
|
||||
"""
|
||||
|
||||
import threading
|
||||
import yaml
|
||||
@@ -33,7 +66,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
def __init__(self, config_file: Path):
|
||||
"""Initialize ModelConfigStore object with a .yaml file."""
|
||||
super().__init__()
|
||||
self._filename = Path(config_file)
|
||||
self._filename = Path(config_file).absolute() # don't let chdir mess us up!
|
||||
self._lock = threading.RLock()
|
||||
if not self._filename.exists():
|
||||
self._initialize_yaml()
|
||||
@@ -51,7 +84,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
def _commit(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
newfile = Path(str(self._filename)+'.new')
|
||||
newfile = Path(str(self._filename) + ".new")
|
||||
yaml_str = OmegaConf.to_yaml(self._config)
|
||||
with open(newfile, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(yaml_str)
|
||||
|
||||
@@ -18,76 +18,81 @@ def store(datadir) -> ModelConfigStore:
|
||||
InvokeAIAppConfig.get_config(root=datadir)
|
||||
return ModelConfigStoreYAML(datadir / "configs" / "models.yaml")
|
||||
|
||||
|
||||
def example_config() -> TextualInversionConfig:
|
||||
return TextualInversionConfig(
|
||||
path='/tmp/pokemon.bin',
|
||||
name='old name',
|
||||
base_model='sd-1',
|
||||
model_type='embedding',
|
||||
model_format='embedding_file',
|
||||
author='Anonymous',
|
||||
return TextualInversionConfig(
|
||||
path="/tmp/pokemon.bin",
|
||||
name="old name",
|
||||
base_model="sd-1",
|
||||
model_type="embedding",
|
||||
model_format="embedding_file",
|
||||
author="Anonymous",
|
||||
)
|
||||
|
||||
|
||||
def test_add(store: ModelConfigStore):
|
||||
raw = dict(path='/tmp/foo.ckpt',
|
||||
name='model1',
|
||||
base_model='sd-1',
|
||||
model_type='main',
|
||||
config='/tmp/foo.yaml',
|
||||
model_variant='normal',
|
||||
model_format='checkpoint'
|
||||
)
|
||||
store.add_model('key1', raw)
|
||||
config1 = store.get_model('key1')
|
||||
raw = dict(
|
||||
path="/tmp/foo.ckpt",
|
||||
name="model1",
|
||||
base_model="sd-1",
|
||||
model_type="main",
|
||||
config="/tmp/foo.yaml",
|
||||
model_variant="normal",
|
||||
model_format="checkpoint",
|
||||
)
|
||||
store.add_model("key1", raw)
|
||||
config1 = store.get_model("key1")
|
||||
assert config1 is not None
|
||||
raw['name'] = 'model2'
|
||||
raw['base_model'] = 'sd-2'
|
||||
raw['model_format'] = 'diffusers'
|
||||
raw.pop('config')
|
||||
store.add_model('key2', raw)
|
||||
config2 = store.get_model('key2')
|
||||
assert config1.name == 'model1'
|
||||
assert config2.name == 'model2'
|
||||
assert config1.base_model == 'sd-1'
|
||||
assert config2.base_model == 'sd-2'
|
||||
raw["name"] = "model2"
|
||||
raw["base_model"] = "sd-2"
|
||||
raw["model_format"] = "diffusers"
|
||||
raw.pop("config")
|
||||
store.add_model("key2", raw)
|
||||
config2 = store.get_model("key2")
|
||||
assert config1.name == "model1"
|
||||
assert config2.name == "model2"
|
||||
assert config1.base_model == "sd-1"
|
||||
assert config2.base_model == "sd-2"
|
||||
|
||||
|
||||
def test_update(store: ModelConfigStore):
|
||||
config = example_config()
|
||||
store.add_model('key1', config)
|
||||
config = store.get_model('key1')
|
||||
store.add_model("key1", config)
|
||||
config = store.get_model("key1")
|
||||
assert config.name == "old name"
|
||||
|
||||
config.name = 'new name'
|
||||
store.update_model('key1', config)
|
||||
new_config = store.get_model('key1')
|
||||
assert new_config.name == 'new name'
|
||||
config.name = "new name"
|
||||
store.update_model("key1", config)
|
||||
new_config = store.get_model("key1")
|
||||
assert new_config.name == "new name"
|
||||
|
||||
try:
|
||||
store.update_model('unknown_key', config)
|
||||
store.update_model("unknown_key", config)
|
||||
assert False, "expected UnknownModelException"
|
||||
except UnknownModelException:
|
||||
assert True
|
||||
|
||||
|
||||
def test_delete(store: ModelConfigStore):
|
||||
config = example_config()
|
||||
store.add_model('key1', config)
|
||||
config = store.get_model('key1')
|
||||
store.del_model('key1')
|
||||
store.add_model("key1", config)
|
||||
config = store.get_model("key1")
|
||||
store.del_model("key1")
|
||||
try:
|
||||
config = store.get_model('key1')
|
||||
config = store.get_model("key1")
|
||||
assert False, "expected fetch of deleted model to raise exception"
|
||||
except UnknownModelException:
|
||||
assert True
|
||||
|
||||
try:
|
||||
store.del_model('unknown')
|
||||
store.del_model("unknown")
|
||||
assert False, "expected delete of unknown model to raise exception"
|
||||
except UnknownModelException:
|
||||
assert True
|
||||
|
||||
|
||||
def test_exists(store: ModelConfigStore):
|
||||
config = example_config()
|
||||
store.add_model('key1', config)
|
||||
assert store.exists('key1')
|
||||
assert not store.exists('key2')
|
||||
|
||||
store.add_model("key1", config)
|
||||
assert store.exists("key1")
|
||||
assert not store.exists("key2")
|
||||
|
||||
Reference in New Issue
Block a user