diff --git a/configs/models.yaml b/configs/models.yaml index 8dd792d75e..f3fde45d8f 100644 --- a/configs/models.yaml +++ b/configs/models.yaml @@ -6,15 +6,15 @@ # and the width and height of the images it # was trained on. -laion400m: - config: configs/latent-diffusion/txt2img-1p4B-eval.yaml - weights: models/ldm/text2img-large/model.ckpt - description: Latent Diffusion LAION400M model - width: 256 - height: 256 stable-diffusion-1.4: config: configs/stable-diffusion/v1-inference.yaml weights: models/ldm/stable-diffusion-v1/model.ckpt description: Stable Diffusion inference model version 1.4 width: 512 height: 512 +stable-diffusion-1.5: + config: configs/stable-diffusion/v1-inference.yaml + weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt + description: Stable Diffusion inference model version 1.5 + width: 512 + height: 512 diff --git a/ldm/generate.py b/ldm/generate.py index b21787eb47..c8e495785c 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -55,6 +55,9 @@ torch.randint_like = fix_func(torch.randint_like) torch.bernoulli = fix_func(torch.bernoulli) torch.multinomial = fix_func(torch.multinomial) +# this is fallback model in case no default is defined +FALLBACK_MODEL_NAME='stable-diffusion-1.4' + def fix_func(orig): if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): def new_func(*args, **kw): @@ -147,7 +150,7 @@ class Generate: def __init__( self, - model = 'stable-diffusion-1.4', + model = None, conf = 'configs/models.yaml', embedding_path = None, sampler_name = 'k_lms', @@ -163,7 +166,6 @@ class Generate: free_gpu_mem=False, ): mconfig = OmegaConf.load(conf) - self.model_name = model self.height = None self.width = None self.model_cache = None @@ -210,6 +212,8 @@ class Generate: # model caching system for fast switching self.model_cache = ModelCache(mconfig,self.device,self.precision) + print(f'DEBUG: model={model}, default_model={self.model_cache.default_model()}') + self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME # for VRAM usage statistics self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None @@ -715,8 +719,7 @@ class Generate: model_data = self.model_cache.get_model(model_name) if model_data is None or len(model_data) == 0: - print(f'** Model switch failed **') - return self.model + return None self.model = model_data['model'] self.width = model_data['width'] diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index 26920f28ea..22d845bbec 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -366,17 +366,16 @@ class Args(object): deprecated_group.add_argument('--laion400m') deprecated_group.add_argument('--weights') # deprecated model_group.add_argument( - '--conf', + '--config', '-c', - '-conf', + '-config', dest='conf', default='./configs/models.yaml', help='Path to configuration file for alternate models.', ) model_group.add_argument( '--model', - default='stable-diffusion-1.4', - help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")', + help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)', ) model_group.add_argument( '--png_compression','-z', @@ -529,7 +528,7 @@ class Args(object): formatter_class=ArgFormatter, description= """ - *Image generation:* + *Image generation* invoke> a fantastic alien landscape -W576 -H512 -s60 -n4 *postprocessing* @@ -544,6 +543,13 @@ class Args(object): !history lists all the commands issued during the current session. !NN retrieves the NNth command from the history + + *Model manipulation* + !models -- list models in configs/models.yaml + !switch -- switch to model named + !import_model path/to/weights/file.ckpt -- adds a model to your config + !edit_model -- edit a model's description + !del_model -- delete a model """ ) render_group = parser.add_argument_group('General rendering') diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index 5c6816e3c3..5e9e53cfb7 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -73,7 +73,8 @@ class ModelCache(object): except Exception as e: print(f'** model {model_name} could not be loaded: {str(e)}') print(f'** restoring {self.current_model}') - return self.get_model(self.current_model) + self.get_model(self.current_model) + return None self.current_model = model_name self._push_newest_model(model_name) @@ -84,6 +85,26 @@ class ModelCache(object): 'hash': hash } + def default_model(self) -> str: + ''' + Returns the name of the default model, or None + if none is defined. + ''' + for model_name in self.config: + if self.config[model_name].get('default',False): + return model_name + return None + + def set_default_model(self,model_name:str): + ''' + Set the default model. The change will not take + effect until you call model_cache.commit() + ''' + assert model_name in self.models,f"unknown model '{model_name}'" + for model in self.models: + self.models[model].pop('default',None) + self.models[model_name]['default'] = True + def list_models(self) -> dict: ''' Return a dict of models in the format: @@ -121,12 +142,23 @@ class ModelCache(object): else: print(line) - def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str: + def del_model(self, model_name:str) ->bool: + ''' + Delete the named model. + ''' + omega = self.config + del omega[model_name] + if model_name in self.stack: + self.stack.remove(model_name) + return True + + def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True: ''' Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory and a YAML - string will be returned. + On a successful update, the config will be changed in memory and the + method will return True. Will fail with an assertion error if provided + attributes are incorrect or the model name is missing. ''' omega = self.config # check that all the required fields are present @@ -139,7 +171,9 @@ class ModelCache(object): config[field] = model_attributes[field] omega[model_name] = config - return OmegaConf.to_yaml(omega) + if clobber: + self._invalidate_cached_model(model_name) + return True def _check_memory(self): avail_memory = psutil.virtual_memory()[1] @@ -219,6 +253,36 @@ class ModelCache(object): if self._has_cuda(): torch.cuda.empty_cache() + def commit(self,config_file_path:str): + ''' + Write current configuration out to the indicated file. + ''' + yaml_str = OmegaConf.to_yaml(self.config) + tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp') + with open(tmpfile, 'w') as outfile: + outfile.write(self.preamble()) + outfile.write(yaml_str) + os.rename(tmpfile,config_file_path) + + def preamble(self): + ''' + Returns the preamble for the config file. + ''' + return '''# This file describes the alternative machine learning models +# available to the dream script. +# +# To add a new model, follow the examples below. Each +# model requires a model config file, a weights file, +# and the width and height of the images it +# was trained on. +''' + + def _invalidate_cached_model(self,model_name:str): + self.unload_model(model_name) + if model_name in self.stack: + self.stack.remove(model_name) + self.models.pop(model_name,None) + def _model_to_cpu(self,model): if self.device != 'cpu': model.cond_stage_model.device = 'cpu' diff --git a/ldm/invoke/readline.py b/ldm/invoke/readline.py index d7cff45bfa..7d87ede755 100644 --- a/ldm/invoke/readline.py +++ b/ldm/invoke/readline.py @@ -57,12 +57,13 @@ COMMANDS = ( '--png_compression','-z', '--text_mask','-tm', '!fix','!fetch','!replay','!history','!search','!clear', + '!models','!switch','!import_model','!edit_model','!del_model', '!mask', - '!models','!switch','!import_model','!edit_model' ) MODEL_COMMANDS = ( '!switch', '!edit_model', + '!del_model', ) WEIGHT_COMMANDS = ( '!import_model', @@ -218,9 +219,24 @@ class Completer(object): pydoc.pager('\n'.join(lines)) def set_line(self,line)->None: + ''' + Set the default string displayed in the next line of input. + ''' self.linebuffer = line readline.redisplay() + def add_model(self,model_name:str)->None: + ''' + add a model name to the completion list + ''' + self.models.append(model_name) + + def del_model(self,model_name:str)->None: + ''' + removes a model name from the completion list + ''' + self.models.remove(model_name) + def _seed_completions(self, text, state): m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text) if m: diff --git a/scripts/invoke.py b/scripts/invoke.py index 84b3835579..3853c586e4 100644 --- a/scripts/invoke.py +++ b/scripts/invoke.py @@ -424,6 +424,15 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple: completer.add_history(command) operation = None + elif command.startswith('!del'): + path = shlex.split(command) + if len(path) < 2: + print('** please provide the name of a model') + else: + del_config(path[1], gen, opt, completer) + completer.add_history(command) + operation = None + elif command.startswith('!fetch'): file_path = command.replace('!fetch','',1).strip() retrieve_dream_command(opt,file_path,completer) @@ -498,9 +507,25 @@ def add_weights_to_config(model_path:str, gen, opt, completer): except: print('** Please enter a valid integer between 64 and 2048') - if write_config_file(opt.conf, gen, model_name, new_config): - gen.set_model(model_name) + make_default = input('Make this the default model? [n] ') in ('y','Y') + + if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default): + completer.add_model(model_name) +def del_config(model_name:str, gen, opt, completer): + current_model = gen.model_name + if model_name == current_model: + print("** Can't delete active model. !switch to another model first. **") + return + yaml_str = gen.model_cache.del_model(model_name) + + tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp') + with open(tmpfile, 'w') as outfile: + outfile.write(yaml_str) + os.rename(tmpfile,opt.conf) + print(f'** {model_name} deleted') + completer.del_model(model_name) + def edit_config(model_name:str, gen, opt, completer): config = gen.model_cache.config @@ -517,28 +542,41 @@ def edit_config(model_name:str, gen, opt, completer): completer.linebuffer = str(conf[field]) if field in conf else '' new_value = input(f'{field}: ') new_config[field] = int(new_value) if field in ('width','height') else new_value + make_default = input('Make this the default model? [n] ') in ('y','Y') completer.complete_extensions(None) - - if write_config_file(opt.conf, gen, model_name, new_config, clobber=True): - gen.set_model(model_name) + write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default) + +def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False): + current_model = gen.model_name -def write_config_file(conf_path, gen, model_name, new_config, clobber=False): op = 'modify' if clobber else 'import' print('\n>> New configuration:') + if make_default: + new_config['default'] = True print(yaml.dump({model_name:new_config})) if input(f'OK to {op} [n]? ') not in ('y','Y'): return False try: + print('>> Verifying that new model loads...') yaml_str = gen.model_cache.add_model(model_name, new_config, clobber) + assert gen.set_model(model_name) is not None, 'model failed to load' except AssertionError as e: - print(f'** configuration failed: {str(e)}') + print(f'** aborting **') + gen.model_cache.del_model(model_name) return False + + if make_default: + print('making this default') + gen.model_cache.set_default_model(model_name) + + gen.model_cache.commit(conf_path) - tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp') - with open(tmpfile, 'w') as outfile: - outfile.write(yaml_str) - os.rename(tmpfile,conf_path) + do_switch = input(f'Keep model loaded? [y]') + if len(do_switch)==0 or do_switch[0] in ('y','Y'): + pass + else: + gen.set_model(current_model) return True def do_textmask(gen, opt, callback):