save name of last model to disk whenever model changes

- this allows invokeai to restore the last used model on startup, even
  after a crash or keyboard interrupt.
This commit is contained in:
Lincoln Stein
2023-04-02 15:45:43 -04:00
parent fd74f51384
commit d4d3441a52
2 changed files with 10 additions and 10 deletions

View File

@@ -22,6 +22,7 @@ import transformers
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from pathlib import Path
from PIL import Image, ImageOps
from pytorch_lightning import logging, seed_everything
@@ -991,8 +992,17 @@ class Generate:
self.model_name = model_name
self._set_sampler() # requires self.model_name to be set first
self._save_last_used_model(model_name)
return self.model
def _save_last_used_model(self,model_name:str):
"""
Save name of the last model used.
"""
model_file_path = Path(Globals.root,'.last_model')
with open(model_file_path,'w') as f:
f.write(model_name)
def load_huggingface_concepts(self, concepts: list[str]):
self.model.textual_inversion_manager.load_huggingface_concepts(concepts)

View File

@@ -181,7 +181,6 @@ def main():
# web server loops forever
if opt.web or opt.gui:
invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan)
save_last_used_model(gen.model_name)
sys.exit(0)
if not infile:
@@ -502,7 +501,6 @@ def main_loop(gen, opt, completer):
print(
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
)
save_last_used_model(gen.model_name)
# TO DO: remove repetitive code and the awkward command.replace() trope
@@ -1300,14 +1298,6 @@ def retrieve_last_used_model()->str:
with open(model_file_path,'r') as f:
return f.readline()
def save_last_used_model(model_name:str):
"""
Save name of the last model used.
"""
model_file_path = Path(Globals.root,'.last_model')
with open(model_file_path,'w') as f:
f.write(model_name)
# This routine performs any patch-ups needed after installation
def run_patches():
install_missing_config_files()