From 352155754127661dd6e46fd4aeb80de7059adc60 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 25 Dec 2022 09:28:46 +1300 Subject: [PATCH] Model Manager Backend Implementation --- backend/invoke_ai_web_server.py | 83 +++++++++++++++++++++++++++++++++ ldm/invoke/model_cache.py | 24 +++++++++- 2 files changed, 106 insertions(+), 1 deletion(-) diff --git a/backend/invoke_ai_web_server.py b/backend/invoke_ai_web_server.py index 22d179604a..ded8618ab0 100644 --- a/backend/invoke_ai_web_server.py +++ b/backend/invoke_ai_web_server.py @@ -9,6 +9,7 @@ import io import base64 import os import json +import tkinter as tk from werkzeug.utils import secure_filename from flask import Flask, redirect, send_from_directory, request, make_response @@ -17,6 +18,7 @@ from PIL import Image, ImageOps from PIL.Image import Image as ImageType from uuid import uuid4 from threading import Event +from tkinter import filedialog from ldm.generate import Generate from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash @@ -297,6 +299,87 @@ class InvokeAIWebServer: config["infill_methods"] = infill_methods() socketio.emit("systemConfig", config) + @socketio.on('searchForModels') + def handle_search_models(): + try: + # Using tkinter to get the filepath because JS doesn't allow + root = tk.Tk() + root.iconify() # for macos + root.withdraw() + root.wm_attributes('-topmost', 1) + root.focus_force() + search_folder = filedialog.askdirectory(parent=root, title='Select Checkpoint Folder') + root.destroy() + + if not search_folder: + socketio.emit( + "foundModels", + {'search_folder': None, 'found_models': None}, + ) + else: + search_folder, found_models = self.generate.model_cache.search_models(search_folder) + socketio.emit( + "foundModels", + {'search_folder': search_folder, 'found_models': found_models}, + ) + except Exception as e: + self.socketio.emit("error", {"message": (str(e))}) + print("\n") + + traceback.print_exc() + print("\n") + + @socketio.on("addNewModel") + def handle_add_model(new_model_config: dict): + try: + model_name = new_model_config['name'] + del new_model_config['name'] + model_attributes = new_model_config + update = False + current_model_list = self.generate.model_cache.list_models() + if model_name in current_model_list: + update = True + + print(f">> Adding New Model: {model_name}") + + self.generate.model_cache.add_model( + model_name=model_name, model_attributes=model_attributes, clobber=True) + self.generate.model_cache.commit(opt.conf) + + new_model_list = self.generate.model_cache.list_models() + socketio.emit( + "newModelAdded", + {"new_model_name": model_name, + "model_list": new_model_list, 'update': update}, + ) + print(f">> New Model Added: {model_name}") + except Exception as e: + self.socketio.emit("error", {"message": (str(e))}) + print("\n") + + traceback.print_exc() + print("\n") + + @socketio.on("deleteModel") + def handle_delete_model(model_name: str): + try: + print(f">> Deleting Model: {model_name}") + self.generate.model_cache.del_model(model_name) + self.generate.model_cache.commit(opt.conf) + updated_model_list = self.generate.model_cache.list_models() + socketio.emit( + "modelDeleted", + {"deleted_model_name": model_name, + "model_list": updated_model_list}, + ) + print(f">> Model Deleted: {model_name}") + except Exception as e: + self.socketio.emit("error", {"message": (str(e))}) + print("\n") + + traceback.print_exc() + print("\n") + @socketio.on("requestModelChange") def handle_set_model(model_name: str): try: diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index e1f07f7c4c..2f670eac62 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -23,6 +23,7 @@ from omegaconf.errors import ConfigAttributeError from ldm.util import instantiate_from_config, ask_user from ldm.invoke.globals import Globals from picklescan.scanner import scan_file_path +from pathlib import Path DEFAULT_MAX_MODELS=2 @@ -135,8 +136,10 @@ class ModelCache(object): for name in self.config: try: description = self.config[name].description + weights = self.config[name].weights except ConfigAttributeError: description = '' + weights = '' if self.current_model == name: status = 'active' @@ -147,7 +150,8 @@ class ModelCache(object): models[name]={ 'status' : status, - 'description' : description + 'description' : description, + 'weights': weights } return models @@ -186,6 +190,8 @@ class ModelCache(object): config = omega[model_name] if model_name in omega else {} for field in model_attributes: + if field == 'weights': + field.replace('\\', '/') config[field] = model_attributes[field] omega[model_name] = config @@ -311,6 +317,22 @@ class ModelCache(object): sys.exit() else: print('>> Model Scanned. OK!!') + + def search_models(self, search_folder): + + print(f'>> Finding Models In: {search_folder}') + models_folder = Path(search_folder).glob('**/*.ckpt') + + files = [x for x in models_folder if x.is_file()] + + found_models = [] + for file in files: + found_models.append({ + 'name': file.stem, + 'location': str(file.resolve()).replace('\\', '/') + }) + + return search_folder, found_models def _make_cache_room(self) -> None: num_loaded_models = len(self.models)