From a705a5a0aa3f40f9cd7f7bd330d1c1594d207bd8 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 15 Oct 2022 15:46:29 -0400
Subject: [PATCH 1/2] enhance support for model switching and editing

- Error checks for invalid model
- Add !del_model command to invoke.py
- Add del_model() method to model_cache
- Autocompleter kept in sync with model addition/subtraction.
---
 ldm/generate.py           |  3 +--
 ldm/invoke/args.py        |  9 +++++++-
 ldm/invoke/model_cache.py | 13 +++++++++++-
 ldm/invoke/readline.py    | 18 +++++++++++++++-
 scripts/invoke.py         | 43 +++++++++++++++++++++++++++++++++------
 5 files changed, 75 insertions(+), 11 deletions(-)

diff --git a/ldm/generate.py b/ldm/generate.py
index fe2dffb1d7..0f543e97ec 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -683,8 +683,7 @@ class Generate:
 
         model_data = self.model_cache.get_model(model_name)
         if model_data is None or len(model_data) == 0:
-            print(f'** Model switch failed **')
-            return self.model
+            return None
 
         self.model = model_data['model']
         self.width = model_data['width']
diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py
index 4c6f69fe53..2b14129a84 100644
--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -519,7 +519,7 @@ class Args(object):
             formatter_class=ArgFormatter,
             description=
             """
-            *Image generation:*
+            *Image generation*
                  invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
 
             *postprocessing*
@@ -534,6 +534,13 @@ class Args(object):
                 !history lists all the commands issued during the current session.
 
                 !NN retrieves the NNth command from the history
+
+            *Model manipulation*
+                !models                                  -- list models in configs/models.yaml
+                !switch <model_name>                     -- switch to model named <model_name>
+                !import_model path/to/weights/file.ckpt  -- adds a model to your config
+                !edit_model <model_name>                 -- edit a model's description
+                !del_model <model_name>                  -- delete a model
             """
         )
         render_group = parser.add_argument_group('General rendering')
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index 5c6816e3c3..eecec5ff9d 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -73,7 +73,8 @@ class ModelCache(object):
             except Exception as e:
                 print(f'** model {model_name} could not be loaded: {str(e)}')
                 print(f'** restoring {self.current_model}')
-                return self.get_model(self.current_model)
+                self.get_model(self.current_model)
+                return None
 
         self.current_model = model_name
         self._push_newest_model(model_name)
@@ -121,6 +122,16 @@ class ModelCache(object):
             else:
                 print(line)
 
+    def del_model(self, model_name:str) ->str:
+        '''
+        Delete the named model and return the YAML
+        '''
+        omega = self.config
+        del omega[model_name]
+        if model_name in self.stack:
+            self.stack.remove(model_name)
+        return OmegaConf.to_yaml(omega)
+
     def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
         '''
         Update the named model with a dictionary of attributes. Will fail with an
diff --git a/ldm/invoke/readline.py b/ldm/invoke/readline.py
index e6ba39e793..2292f11b59 100644
--- a/ldm/invoke/readline.py
+++ b/ldm/invoke/readline.py
@@ -53,11 +53,12 @@ COMMANDS = (
     '--log_tokenization','-t',
     '--hires_fix',
     '!fix','!fetch','!history','!search','!clear',
-    '!models','!switch','!import_model','!edit_model'
+    '!models','!switch','!import_model','!edit_model','!del_model',
     )
 MODEL_COMMANDS = (
     '!switch',
     '!edit_model',
+    '!del_model',
     )
 WEIGHT_COMMANDS = (
     '!import_model',
@@ -205,9 +206,24 @@ class Completer(object):
         pydoc.pager('\n'.join(lines))
 
     def set_line(self,line)->None:
+        '''
+        Set the default string displayed in the next line of input.
+        '''
         self.linebuffer = line
         readline.redisplay()
+
+    def add_model(self,model_name:str)->None:
+        '''
+        Add a model name to the completion list.
+        '''
+        self.models.append(model_name)
+
+    def del_model(self,model_name:str)->None:
+        '''
+        Remove a model name from the completion list.
+        '''
+        self.models.remove(model_name)
 
     def _seed_completions(self, text, state):
         m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
         if m:
diff --git a/scripts/invoke.py b/scripts/invoke.py
index fbee218b78..7b9c574913 100644
--- a/scripts/invoke.py
+++ b/scripts/invoke.py
@@ -381,6 +381,15 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
         completer.add_history(command)
         operation = None
 
+    elif command.startswith('!del'):
+        path = shlex.split(command)
+        if len(path) < 2:
+            print('** please provide the name of a model')
+        else:
+            del_config(path[1], gen, opt, completer)
+        completer.add_history(command)
+        operation = None
+
     elif command.startswith('!fetch'):
         file_path = command.replace('!fetch ','',1)
         retrieve_dream_command(opt,file_path,completer)
@@ -446,10 +455,23 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
             done = True
         except:
             print('** Please enter a valid integer between 64 and 2048')
 
     if write_config_file(opt.conf, gen, model_name, new_config):
-        gen.set_model(model_name)
+        completer.add_model(model_name)
 
+def del_config(model_name:str, gen, opt, completer):
+    current_model = gen.model_name
+    if model_name == current_model:
+        print("** Can't delete active model. !switch to another model first. **")
+        return
+    yaml_str = gen.model_cache.del_model(model_name)
+
+    tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp')
+    with open(tmpfile, 'w') as outfile:
+        outfile.write(yaml_str)
+    os.rename(tmpfile,opt.conf)
+    print(f'** {model_name} deleted')
+    completer.del_model(model_name)
+
 def edit_config(model_name:str, gen, opt, completer):
     config = gen.model_cache.config
 
@@ -467,11 +489,11 @@ def edit_config(model_name:str, gen, opt, completer):
             new_value = input(f'{field}: ')
             new_config[field] = int(new_value) if field in ('width','height') else new_value
         completer.complete_extensions(None)
-
-    if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
-        gen.set_model(model_name)
+    write_config_file(opt.conf, gen, model_name, new_config, clobber=True)
 
 def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
+    current_model = gen.model_name
+
     op = 'modify' if clobber else 'import'
     print('\n>> New configuration:')
     print(yaml.dump({model_name:new_config}))
@@ -479,15 +501,24 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
         return False
 
     try:
+        print('>> Verifying that new model loads...')
         yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
+        assert gen.set_model(model_name) is not None, 'model failed to load'
     except AssertionError as e:
-        print(f'** configuration failed: {str(e)}')
+        print(f'** aborting **')
+        gen.model_cache.del_model(model_name)
         return False
 
     tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
     with open(tmpfile, 'w') as outfile:
         outfile.write(yaml_str)
     os.rename(tmpfile,conf_path)
+
+    do_switch = input(f'Keep model loaded? [y]')
+    if len(do_switch)==0 or do_switch[0] in ('y','Y'):
+        pass
+    else:
+        gen.set_model(current_model)
     return True
 
 def do_postprocess (gen, opt, callback):

From 83e6ab08aae620745f8bddb587b6799a1d70b7ea Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Fri, 21 Oct 2022 00:28:54 -0400
Subject: [PATCH 2/2] further improvements to model loading

- code for committing config changes to models.yaml now in module
  rather than in invoke script
- model marked "default" is now loaded if model not specified on
  command line
- uncache changed models when edited, so that they reload properly
- removed laion from models.yaml and added stable-diffusion-1.5
---
 configs/models.yaml       | 12 +++----
 ldm/generate.py           |  8 +++--
 ldm/invoke/args.py        |  7 ++--
 ldm/invoke/model_cache.py | 67 +++++++++++++++++++++++++++++++++++----
 scripts/invoke.py         | 24 +++++++++-----
 5 files changed, 91 insertions(+), 27 deletions(-)

diff --git a/configs/models.yaml b/configs/models.yaml
index 8dd792d75e..f3fde45d8f 100644
--- a/configs/models.yaml
+++ b/configs/models.yaml
@@ -6,15 +6,15 @@
 # and the width and height of the images it
 # was trained on.
 
-laion400m:
-  config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
-  weights: models/ldm/text2img-large/model.ckpt
-  description: Latent Diffusion LAION400M model
-  width: 256
-  height: 256
 stable-diffusion-1.4:
   config: configs/stable-diffusion/v1-inference.yaml
   weights: models/ldm/stable-diffusion-v1/model.ckpt
   description: Stable Diffusion inference model version 1.4
   width: 512
   height: 512
+stable-diffusion-1.5:
+  config: configs/stable-diffusion/v1-inference.yaml
+  weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
+  description: Stable Diffusion inference model version 1.5
+  width: 512
+  height: 512
diff --git a/ldm/generate.py b/ldm/generate.py
index 0f543e97ec..5f3a6fd4b5 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -35,6 +35,9 @@ from ldm.invoke.devices import choose_torch_device, choose_precision
 from ldm.invoke.conditioning import get_uc_and_c
 from ldm.invoke.model_cache import ModelCache
 
+# this is the fallback model in case no default is defined
+FALLBACK_MODEL_NAME='stable-diffusion-1.4'
+
 def fix_func(orig):
     if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
         def new_func(*args, **kw):
@@ -127,7 +130,7 @@ class Generate:
 
     def __init__(
             self,
-            model = 'stable-diffusion-1.4',
+            model = None,
             conf = 'configs/models.yaml',
             embedding_path = None,
             sampler_name = 'k_lms',
@@ -143,7 +146,6 @@ class Generate:
             free_gpu_mem=False,
     ):
         mconfig = OmegaConf.load(conf)
-        self.model_name = model
         self.height = None
         self.width = None
         self.model_cache = None
@@ -188,6 +190,8 @@ class Generate:
 
         # model caching system for fast switching
         self.model_cache = ModelCache(mconfig,self.device,self.precision)
+        print(f'DEBUG: model={model}, default_model={self.model_cache.default_model()}')
+        self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
 
         # for VRAM usage statistics
         self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py
index 2b14129a84..c997f45da2 100644
--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -364,17 +364,16 @@ class Args(object):
         deprecated_group.add_argument('--laion400m')
         deprecated_group.add_argument('--weights') # deprecated
         model_group.add_argument(
-            '--conf',
+            '--config',
             '-c',
-            '-conf',
+            '-config',
             dest='conf',
             default='./configs/models.yaml',
             help='Path to configuration file for alternate models.',
         )
         model_group.add_argument(
             '--model',
-            default='stable-diffusion-1.4',
-            help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
+            help='Indicates which diffusion model to load (defaults to the stanza marked "default" in configs/models.yaml)',
         )
         model_group.add_argument(
             '--sampler',
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index eecec5ff9d..5e9e53cfb7 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -85,6 +85,26 @@ class ModelCache(object):
             'hash': hash
         }
 
+    def default_model(self) -> str:
+        '''
+        Returns the name of the default model, or None
+        if none is defined.
+        '''
+        for model_name in self.config:
+            if self.config[model_name].get('default',False):
+                return model_name
+        return None
+
+    def set_default_model(self,model_name:str):
+        '''
+        Set the default model. The change will not take
+        effect until you call model_cache.commit().
+        '''
+        assert model_name in self.config, f"unknown model '{model_name}'"
+        for model in self.config:
+            self.config[model].pop('default',None)
+        self.config[model_name]['default'] = True
+
     def list_models(self) -> dict:
         '''
         Return a dict of models in the format:
@@ -122,22 +142,23 @@ class ModelCache(object):
         else:
             print(line)
 
-    def del_model(self, model_name:str) ->str:
+    def del_model(self, model_name:str) ->bool:
         '''
-        Delete the named model and return the YAML
+        Delete the named model.
         '''
         omega = self.config
         del omega[model_name]
         if model_name in self.stack:
             self.stack.remove(model_name)
-        return OmegaConf.to_yaml(omega)
+        return True
 
-    def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
+    def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->bool:
         '''
         Update the named model with a dictionary of attributes. Will fail with an
         assertion error if the name already exists. Pass clobber=True to overwrite.
-        On a successful update, the config will be changed in memory and a YAML
-        string will be returned.
+        On a successful update, the config will be changed in memory and the
+        method will return True. An assertion error is also raised if the
+        provided attributes are incomplete or otherwise invalid.
         '''
         omega = self.config
         # check that all the required fields are present
@@ -150,7 +171,9 @@ class ModelCache(object):
             config[field] = model_attributes[field]
 
         omega[model_name] = config
-        return OmegaConf.to_yaml(omega)
+        if clobber:
+            self._invalidate_cached_model(model_name)
+        return True
 
     def _check_memory(self):
         avail_memory = psutil.virtual_memory()[1]
@@ -230,6 +253,36 @@ class ModelCache(object):
         if self._has_cuda():
             torch.cuda.empty_cache()
 
+    def commit(self,config_file_path:str):
+        '''
+        Write the current configuration out to the indicated file.
+        '''
+        yaml_str = OmegaConf.to_yaml(self.config)
+        tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
+        with open(tmpfile, 'w') as outfile:
+            outfile.write(self.preamble())
+            outfile.write(yaml_str)
+        os.rename(tmpfile,config_file_path)
+
+    def preamble(self):
+        '''
+        Returns the preamble for the config file.
+        '''
+        return '''# This file describes the alternative machine learning models
+# available to the dream script.
+#
+# To add a new model, follow the examples below. Each
+# model requires a model config file, a weights file,
+# and the width and height of the images it
+# was trained on.
+''' + + def _invalidate_cached_model(self,model_name:str): + self.unload_model(model_name) + if model_name in self.stack: + self.stack.remove(model_name) + self.models.pop(model_name,None) + def _model_to_cpu(self,model): if self.device != 'cpu': model.cond_stage_model.device = 'cpu' diff --git a/scripts/invoke.py b/scripts/invoke.py index 7b9c574913..8aedfb87a5 100644 --- a/scripts/invoke.py +++ b/scripts/invoke.py @@ -341,6 +341,7 @@ def main_loop(gen, opt, infile): print('goodbye!') + # to do: this is ugly, fix def do_command(command:str, gen, opt:Args, completer) -> tuple: operation = 'generate' # default operation, alternative is 'postprocess' @@ -455,7 +456,10 @@ def add_weights_to_config(model_path:str, gen, opt, completer): done = True except: print('** Please enter a valid integer between 64 and 2048') - if write_config_file(opt.conf, gen, model_name, new_config): + + make_default = input('Make this the default model? [n] ') in ('y','Y') + + if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default): completer.add_model(model_name) def del_config(model_name:str, gen, opt, completer): @@ -488,14 +492,17 @@ def edit_config(model_name:str, gen, opt, completer): completer.linebuffer = str(conf[field]) if field in conf else '' new_value = input(f'{field}: ') new_config[field] = int(new_value) if field in ('width','height') else new_value + make_default = input('Make this the default model? [n] ') in ('y','Y') completer.complete_extensions(None) - write_config_file(opt.conf, gen, model_name, new_config, clobber=True) + write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default) -def write_config_file(conf_path, gen, model_name, new_config, clobber=False): +def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False): current_model = gen.model_name op = 'modify' if clobber else 'import' print('\n>> New configuration:') + if make_default: + new_config['default'] = True print(yaml.dump({model_name:new_config})) if input(f'OK to {op} [n]? ') not in ('y','Y'): return False @@ -508,12 +515,13 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False): print(f'** aborting **') gen.model_cache.del_model(model_name) return False - - tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp') - with open(tmpfile, 'w') as outfile: - outfile.write(yaml_str) - os.rename(tmpfile,conf_path) + if make_default: + print('making this default') + gen.model_cache.set_default_model(model_name) + + gen.model_cache.commit(conf_path) + do_switch = input(f'Keep model loaded? [y]') if len(do_switch)==0 or do_switch[0] in ('y','Y'): pass
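
For orientation, here is a minimal sketch (not part of either patch above) of how the new ModelCache calls introduced in PATCH 2/2 fit together. It assumes the post-patch tree; the config path, device string, precision string, and the 'my-old-model' stanza name are illustrative assumptions, not values taken from the patches:

    # Illustrative sketch of the model-management API added above.
    from omegaconf import OmegaConf
    from ldm.invoke.model_cache import ModelCache

    config_path = 'configs/models.yaml'
    # device and precision values here are assumptions for the example
    cache = ModelCache(OmegaConf.load(config_path), 'cuda', 'float16')

    # default_model() returns the name of the stanza marked "default", or None
    print(cache.default_model())

    # changes are staged in the in-memory config first ...
    cache.set_default_model('stable-diffusion-1.5')
    cache.del_model('my-old-model')   # hypothetical stanza name

    # ... and persisted only when commit() rewrites models.yaml,
    # prepending the preamble() comment block
    cache.commit(config_path)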