add class docstring and blackify

This commit is contained in:
Lincoln Stein
2023-08-12 20:13:00 -04:00
parent 32958db6f6
commit b2894b5270
3 changed files with 84 additions and 47 deletions

View File

@@ -79,4 +79,3 @@ class ModelConfigStore(ABC):
:param key: Unique key for the model to be deleted
"""
pass

View File

@@ -1,5 +1,38 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""Implementation of ModelConfigStore using a YAML file."""
"""
Implementation of ModelConfigStore using a YAML file.
Typical usage:
from invokeai.backend.model_management2.storage.yaml import ModelConfigStoreYAML
store = ModelConfigStoreYAML("./configs/models.yaml")
config = dict(
path='/tmp/pokemon.bin',
name='old name',
base_model='sd-1',
model_type='embedding',
model_format='embedding_file',
author='Anonymous',
)
# adding
store.add_model('key1', config)
# updating
config.name='new name'
store.update_model('key1', config)
# checking for existence
if store.exists('key1'):
print("yes")
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base_model)
# deleting
store.del_model('key1')
"""
import threading
import yaml
@@ -33,7 +66,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
def __init__(self, config_file: Path):
"""Initialize ModelConfigStore object with a .yaml file."""
super().__init__()
self._filename = Path(config_file)
self._filename = Path(config_file).absolute() # don't let chdir mess us up!
self._lock = threading.RLock()
if not self._filename.exists():
self._initialize_yaml()
@@ -51,7 +84,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
def _commit(self):
try:
self._lock.acquire()
newfile = Path(str(self._filename)+'.new')
newfile = Path(str(self._filename) + ".new")
yaml_str = OmegaConf.to_yaml(self._config)
with open(newfile, "w", encoding="utf-8") as outfile:
outfile.write(yaml_str)

View File

@@ -18,76 +18,81 @@ def store(datadir) -> ModelConfigStore:
InvokeAIAppConfig.get_config(root=datadir)
return ModelConfigStoreYAML(datadir / "configs" / "models.yaml")
def example_config() -> TextualInversionConfig:
return TextualInversionConfig(
path='/tmp/pokemon.bin',
name='old name',
base_model='sd-1',
model_type='embedding',
model_format='embedding_file',
author='Anonymous',
return TextualInversionConfig(
path="/tmp/pokemon.bin",
name="old name",
base_model="sd-1",
model_type="embedding",
model_format="embedding_file",
author="Anonymous",
)
def test_add(store: ModelConfigStore):
raw = dict(path='/tmp/foo.ckpt',
name='model1',
base_model='sd-1',
model_type='main',
config='/tmp/foo.yaml',
model_variant='normal',
model_format='checkpoint'
)
store.add_model('key1', raw)
config1 = store.get_model('key1')
raw = dict(
path="/tmp/foo.ckpt",
name="model1",
base_model="sd-1",
model_type="main",
config="/tmp/foo.yaml",
model_variant="normal",
model_format="checkpoint",
)
store.add_model("key1", raw)
config1 = store.get_model("key1")
assert config1 is not None
raw['name'] = 'model2'
raw['base_model'] = 'sd-2'
raw['model_format'] = 'diffusers'
raw.pop('config')
store.add_model('key2', raw)
config2 = store.get_model('key2')
assert config1.name == 'model1'
assert config2.name == 'model2'
assert config1.base_model == 'sd-1'
assert config2.base_model == 'sd-2'
raw["name"] = "model2"
raw["base_model"] = "sd-2"
raw["model_format"] = "diffusers"
raw.pop("config")
store.add_model("key2", raw)
config2 = store.get_model("key2")
assert config1.name == "model1"
assert config2.name == "model2"
assert config1.base_model == "sd-1"
assert config2.base_model == "sd-2"
def test_update(store: ModelConfigStore):
config = example_config()
store.add_model('key1', config)
config = store.get_model('key1')
store.add_model("key1", config)
config = store.get_model("key1")
assert config.name == "old name"
config.name = 'new name'
store.update_model('key1', config)
new_config = store.get_model('key1')
assert new_config.name == 'new name'
config.name = "new name"
store.update_model("key1", config)
new_config = store.get_model("key1")
assert new_config.name == "new name"
try:
store.update_model('unknown_key', config)
store.update_model("unknown_key", config)
assert False, "expected UnknownModelException"
except UnknownModelException:
assert True
def test_delete(store: ModelConfigStore):
config = example_config()
store.add_model('key1', config)
config = store.get_model('key1')
store.del_model('key1')
store.add_model("key1", config)
config = store.get_model("key1")
store.del_model("key1")
try:
config = store.get_model('key1')
config = store.get_model("key1")
assert False, "expected fetch of deleted model to raise exception"
except UnknownModelException:
assert True
try:
store.del_model('unknown')
store.del_model("unknown")
assert False, "expected delete of unknown model to raise exception"
except UnknownModelException:
assert True
def test_exists(store: ModelConfigStore):
config = example_config()
store.add_model('key1', config)
assert store.exists('key1')
assert not store.exists('key2')
store.add_model("key1", config)
assert store.exists("key1")
assert not store.exists("key2")