diff --git a/ldm/generate.py b/ldm/generate.py
index a437d3baf4..e2d4a40de7 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -802,6 +802,10 @@ class Generate:
         # the model cache does the loading and offloading
         cache = self.model_cache
 
+        if not cache.valid_model(model_name):
+            print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
+            return self.model
+
         cache.print_vram_usage()
 
         # have to get rid of all references to model in order
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index 7b434941df..1999973ea8 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -41,15 +41,22 @@ class ModelCache(object):
         self.stack = []  # this is an LRU FIFO
         self.current_model = None
 
+    def valid_model(self, model_name:str)->bool:
+        '''
+        Given a model name, returns True if it is a valid
+        identifier.
+        '''
+        return model_name in self.config
+
     def get_model(self, model_name:str):
         '''
         Given a model named identified in models.yaml, return
         the model object. If in RAM will load into GPU VRAM.
         If on disk, will load from there.
         '''
-        if model_name not in self.config:
+        if not self.valid_model(model_name):
             print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
-            return None
+            return self.current_model
 
         if self.current_model != model_name:
             if model_name not in self.models: # make room for a new one