Merge branch 'model-switching' into development

This commit is contained in:
Lincoln Stein
2022-10-21 21:27:59 -04:00
6 changed files with 159 additions and 32 deletions

View File

@@ -55,6 +55,9 @@ torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
# this is fallback model in case no default is defined
FALLBACK_MODEL_NAME='stable-diffusion-1.4'
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
@@ -147,7 +150,7 @@ class Generate:
def __init__(
self,
model = 'stable-diffusion-1.4',
model = None,
conf = 'configs/models.yaml',
embedding_path = None,
sampler_name = 'k_lms',
@@ -163,7 +166,6 @@ class Generate:
free_gpu_mem=False,
):
mconfig = OmegaConf.load(conf)
self.model_name = model
self.height = None
self.width = None
self.model_cache = None
@@ -210,6 +212,8 @@ class Generate:
# model caching system for fast switching
self.model_cache = ModelCache(mconfig,self.device,self.precision)
print(f'DEBUG: model={model}, default_model={self.model_cache.default_model()}')
self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
# for VRAM usage statistics
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
@@ -715,8 +719,7 @@ class Generate:
model_data = self.model_cache.get_model(model_name)
if model_data is None or len(model_data) == 0:
print(f'** Model switch failed **')
return self.model
return None
self.model = model_data['model']
self.width = model_data['width']

View File

@@ -366,17 +366,16 @@ class Args(object):
deprecated_group.add_argument('--laion400m')
deprecated_group.add_argument('--weights') # deprecated
model_group.add_argument(
'--conf',
'--config',
'-c',
'-conf',
'-config',
dest='conf',
default='./configs/models.yaml',
help='Path to configuration file for alternate models.',
)
model_group.add_argument(
'--model',
default='stable-diffusion-1.4',
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
)
model_group.add_argument(
'--png_compression','-z',
@@ -529,7 +528,7 @@ class Args(object):
formatter_class=ArgFormatter,
description=
"""
*Image generation:*
*Image generation*
invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
*postprocessing*
@@ -544,6 +543,13 @@ class Args(object):
!history lists all the commands issued during the current session.
!NN retrieves the NNth command from the history
*Model manipulation*
!models -- list models in configs/models.yaml
!switch <model_name> -- switch to model named <model_name>
!import_model path/to/weights/file.ckpt -- adds a model to your config
!edit_model <model_name> -- edit a model's description
!del_model <model_name> -- delete a model
"""
)
render_group = parser.add_argument_group('General rendering')

View File

@@ -73,7 +73,8 @@ class ModelCache(object):
except Exception as e:
print(f'** model {model_name} could not be loaded: {str(e)}')
print(f'** restoring {self.current_model}')
return self.get_model(self.current_model)
self.get_model(self.current_model)
return None
self.current_model = model_name
self._push_newest_model(model_name)
@@ -84,6 +85,26 @@ class ModelCache(object):
'hash': hash
}
def default_model(self) -> str:
'''
Returns the name of the default model, or None
if none is defined.
'''
for model_name in self.config:
if self.config[model_name].get('default',False):
return model_name
return None
def set_default_model(self,model_name:str):
'''
Set the default model. The change will not take
effect until you call model_cache.commit()
'''
assert model_name in self.models,f"unknown model '{model_name}'"
for model in self.models:
self.models[model].pop('default',None)
self.models[model_name]['default'] = True
def list_models(self) -> dict:
'''
Return a dict of models in the format:
@@ -121,12 +142,23 @@ class ModelCache(object):
else:
print(line)
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
def del_model(self, model_name:str) ->bool:
'''
Delete the named model.
'''
omega = self.config
del omega[model_name]
if model_name in self.stack:
self.stack.remove(model_name)
return True
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
'''
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory and a YAML
string will be returned.
On a successful update, the config will be changed in memory and the
method will return True. Will fail with an assertion error if provided
attributes are incorrect or the model name is missing.
'''
omega = self.config
# check that all the required fields are present
@@ -139,7 +171,9 @@ class ModelCache(object):
config[field] = model_attributes[field]
omega[model_name] = config
return OmegaConf.to_yaml(omega)
if clobber:
self._invalidate_cached_model(model_name)
return True
def _check_memory(self):
avail_memory = psutil.virtual_memory()[1]
@@ -219,6 +253,36 @@ class ModelCache(object):
if self._has_cuda():
torch.cuda.empty_cache()
def commit(self,config_file_path:str):
'''
Write current configuration out to the indicated file.
'''
yaml_str = OmegaConf.to_yaml(self.config)
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
with open(tmpfile, 'w') as outfile:
outfile.write(self.preamble())
outfile.write(yaml_str)
os.rename(tmpfile,config_file_path)
def preamble(self):
'''
Returns the preamble for the config file.
'''
return '''# This file describes the alternative machine learning models
# available to the dream script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
'''
def _invalidate_cached_model(self,model_name:str):
self.unload_model(model_name)
if model_name in self.stack:
self.stack.remove(model_name)
self.models.pop(model_name,None)
def _model_to_cpu(self,model):
if self.device != 'cpu':
model.cond_stage_model.device = 'cpu'

View File

@@ -57,12 +57,13 @@ COMMANDS = (
'--png_compression','-z',
'--text_mask','-tm',
'!fix','!fetch','!replay','!history','!search','!clear',
'!models','!switch','!import_model','!edit_model','!del_model',
'!mask',
'!models','!switch','!import_model','!edit_model'
)
MODEL_COMMANDS = (
'!switch',
'!edit_model',
'!del_model',
)
WEIGHT_COMMANDS = (
'!import_model',
@@ -218,9 +219,24 @@ class Completer(object):
pydoc.pager('\n'.join(lines))
def set_line(self,line)->None:
'''
Set the default string displayed in the next line of input.
'''
self.linebuffer = line
readline.redisplay()
def add_model(self,model_name:str)->None:
'''
add a model name to the completion list
'''
self.models.append(model_name)
def del_model(self,model_name:str)->None:
'''
removes a model name from the completion list
'''
self.models.remove(model_name)
def _seed_completions(self, text, state):
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
if m: