mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
implemented tabbed model selection; not wired to backend yet
This commit is contained in:
@@ -6,6 +6,7 @@ import re
|
||||
import shutil
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryFile
|
||||
from typing import List, Dict
|
||||
@@ -20,9 +21,8 @@ from tqdm import tqdm
|
||||
import invokeai.configs as configs
|
||||
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from ..model_management import ModelManager
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
|
||||
from ..util.logging import InvokeAILogger
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
@@ -37,6 +37,9 @@ Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
# initial models omegaconf
|
||||
Datasets = None
|
||||
|
||||
# logger
|
||||
logger = InvokeAILogger.getLogger(name='InvokeAI')
|
||||
|
||||
Config_preamble = """
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
@@ -47,6 +50,11 @@ Config_preamble = """
|
||||
# was trained on.
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class ModelInstallList:
|
||||
'''Class for listing models to be installed/removed'''
|
||||
install_models: List[str]
|
||||
remove_models: List[str]
|
||||
|
||||
def default_config_file():
|
||||
return config.model_conf_path
|
||||
@@ -61,11 +69,11 @@ def initial_models():
|
||||
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
|
||||
|
||||
def install_requested_models(
|
||||
install_initial_models: List[str] = None,
|
||||
remove_models: List[str] = None,
|
||||
install_cn_models: List[str] = None,
|
||||
remove_cn_models: List[str] = None,
|
||||
cn_model_map: Dict[str,str] = None,
|
||||
diffusers: ModelInstallList = None,
|
||||
controlnet: ModelInstallList = None,
|
||||
lora: ModelInstallList = None,
|
||||
ti: ModelInstallList = None,
|
||||
cn_model_map: Dict[str,str] = None, # temporary - move to model manager
|
||||
scan_directory: Path = None,
|
||||
external_models: List[str] = None,
|
||||
scan_at_startup: bool = False,
|
||||
@@ -81,33 +89,30 @@ def install_requested_models(
|
||||
if not config_file_path.exists():
|
||||
open(config_file_path, "w")
|
||||
|
||||
install_controlnet_models(
|
||||
install_cn_models,
|
||||
short_name_map = cn_model_map,
|
||||
precision=precision,
|
||||
access_token=access_token,
|
||||
)
|
||||
delete_controlnet_models(remove_cn_models)
|
||||
|
||||
# prevent circular import here
|
||||
from ..model_management import ModelManager
|
||||
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
|
||||
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
|
||||
model_manager.delete_controlnet_models(controlnet.remove_models)
|
||||
|
||||
if remove_models and len(remove_models) > 0:
|
||||
print("== DELETING UNCHECKED STARTER MODELS ==")
|
||||
for model in remove_models:
|
||||
print(f"{model}...")
|
||||
# TODO: Replace next three paragraphs with calls into new model manager
|
||||
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
||||
logger.info("DELETING UNCHECKED STARTER MODELS")
|
||||
for model in diffusers.remove_models:
|
||||
logger.info(f"{model}...")
|
||||
model_manager.del_model(model, delete_files=purge_deleted)
|
||||
model_manager.commit(config_file_path)
|
||||
|
||||
if install_initial_models and len(install_initial_models) > 0:
|
||||
print("== INSTALLING SELECTED STARTER MODELS ==")
|
||||
if diffusers.install_models and len(diffusers.install_models) > 0:
|
||||
logger.info("INSTALLING SELECTED STARTER MODELS")
|
||||
successfully_downloaded = download_weight_datasets(
|
||||
models=install_initial_models,
|
||||
models=diffusers.install_initial_models,
|
||||
access_token=None,
|
||||
precision=precision,
|
||||
) # FIX: for historical reasons, we don't use model manager here
|
||||
update_config_file(successfully_downloaded, config_file_path)
|
||||
if len(successfully_downloaded) < len(install_initial_models):
|
||||
print("** Some of the model downloads were not successful")
|
||||
if len(successfully_downloaded) < len(diffusers.install_models):
|
||||
logger.warning("Some of the model downloads were not successful")
|
||||
|
||||
# due to above, we have to reload the model manager because conf file
|
||||
# was changed behind its back
|
||||
@@ -118,7 +123,7 @@ def install_requested_models(
|
||||
external_models.append(str(scan_directory))
|
||||
|
||||
if len(external_models) > 0:
|
||||
print("== INSTALLING EXTERNAL MODELS ==")
|
||||
logger.info("INSTALLING EXTERNAL MODELS")
|
||||
for path_url_or_repo in external_models:
|
||||
try:
|
||||
model_manager.heuristic_import(
|
||||
@@ -190,10 +195,10 @@ def migrate_models_ckpt():
|
||||
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
||||
return
|
||||
new_name = initial_models()["stable-diffusion-1.4"]["file"]
|
||||
print(
|
||||
logger.warning(
|
||||
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
|
||||
)
|
||||
print(f"model.ckpt => {new_name}")
|
||||
logger.warning(f"model.ckpt => {new_name}")
|
||||
os.replace(
|
||||
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
|
||||
)
|
||||
@@ -206,7 +211,7 @@ def download_weight_datasets(
|
||||
migrate_models_ckpt()
|
||||
successful = dict()
|
||||
for mod in models:
|
||||
print(f"Downloading {mod}:")
|
||||
logger.info(f"Downloading {mod}:")
|
||||
successful[mod] = _download_repo_or_file(
|
||||
initial_models()[mod], access_token, precision=precision
|
||||
)
|
||||
@@ -240,68 +245,6 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def install_controlnet_models(
|
||||
short_names: List[str],
|
||||
short_name_map: Dict[str,str],
|
||||
precision: str='float16',
|
||||
access_token: str = None,
|
||||
):
|
||||
'''
|
||||
Download list of controlnet models, using their HuggingFace
|
||||
repo_ids.
|
||||
'''
|
||||
dest_dir = config.controlnet_path
|
||||
if not dest_dir.exists():
|
||||
dest_dir.mkdir(parents=True,exist_ok=False)
|
||||
|
||||
# The model file may be fp32 or fp16, and may be either a
|
||||
# .bin file or a .safetensors. We try each until we get one,
|
||||
# preferring 'fp16' if using half precision, and preferring
|
||||
# safetensors over over bin.
|
||||
precisions = ['.fp16',''] if precision=='float16' else ['']
|
||||
formats = ['.safetensors','.bin']
|
||||
possible_filenames = list()
|
||||
for p in precisions:
|
||||
for f in formats:
|
||||
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
|
||||
|
||||
for directory_name in short_names:
|
||||
repo_id = short_name_map[directory_name]
|
||||
safe_name = directory_name.replace('/','--')
|
||||
print(f'Downloading ControlNet model {directory_name} ({repo_id})')
|
||||
hf_download_with_resume(
|
||||
repo_id = repo_id,
|
||||
model_dir = dest_dir / safe_name,
|
||||
model_name = 'config.json',
|
||||
access_token = access_token
|
||||
)
|
||||
|
||||
path = None
|
||||
for filename in possible_filenames:
|
||||
suffix = filename.suffix
|
||||
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
|
||||
print(f'Probing {directory_name}/{filename}...')
|
||||
path = hf_download_with_resume(
|
||||
repo_id = repo_id,
|
||||
model_dir = dest_dir / safe_name,
|
||||
model_name = str(filename),
|
||||
access_token = access_token,
|
||||
model_dest = Path(dest_dir, safe_name, dest_filename),
|
||||
)
|
||||
if path:
|
||||
(path.parent / '.download_complete').touch()
|
||||
break
|
||||
|
||||
# ---------------------------------------------
|
||||
def delete_controlnet_models(short_names: List[str]):
|
||||
for name in short_names:
|
||||
safe_name = name.replace('/','--')
|
||||
directory = config.controlnet_path / safe_name
|
||||
if directory.exists():
|
||||
print(f'Purging controlnet model {name}')
|
||||
shutil.rmtree(str(directory))
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_from_hf(
|
||||
model_class: object, model_name: str, **kwargs
|
||||
@@ -340,7 +283,7 @@ def _download_diffusion_weights(
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
pass
|
||||
else:
|
||||
print(f"An unexpected error occurred while downloading the model: {e})")
|
||||
logger.error(f"An unexpected error occurred while downloading the model: {e})")
|
||||
if path:
|
||||
break
|
||||
return path
|
||||
@@ -374,17 +317,17 @@ def hf_download_with_resume(
|
||||
if (
|
||||
resp.status_code == 416
|
||||
): # "range not satisfiable", which means nothing to return
|
||||
print(f"* {model_name}: complete file found. Skipping.")
|
||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||
return model_dest
|
||||
elif resp.status_code == 404:
|
||||
print("** File not found")
|
||||
logger.warning("File not found")
|
||||
return None
|
||||
elif resp.status_code != 200:
|
||||
print(f"** Warning: {model_name}: {resp.reason}")
|
||||
logger.warning(f"{model_name}: {resp.reason}")
|
||||
elif exist_size > 0:
|
||||
print(f"* {model_name}: partial file found. Resuming...")
|
||||
logger.info(f"{model_name}: partial file found. Resuming...")
|
||||
else:
|
||||
print(f"* {model_name}: Downloading...")
|
||||
logger.info(f"{model_name}: Downloading...")
|
||||
|
||||
try:
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
@@ -399,7 +342,7 @@ def hf_download_with_resume(
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
print(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
return None
|
||||
return model_dest
|
||||
|
||||
@@ -424,8 +367,8 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
|
||||
try:
|
||||
backup = None
|
||||
if os.path.exists(config_file):
|
||||
print(
|
||||
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
logger.warning(
|
||||
f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
)
|
||||
backup = config_file.with_suffix(".yaml.orig")
|
||||
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
|
||||
@@ -442,16 +385,16 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
|
||||
new_config.write(tmp.read())
|
||||
|
||||
except Exception as e:
|
||||
print(f"**Error creating config file {config_file}: {str(e)} **")
|
||||
logger.error(f"Error creating config file {config_file}: {str(e)}")
|
||||
if backup is not None:
|
||||
print("restoring previous config file")
|
||||
logger.info("restoring previous config file")
|
||||
## workaround, for WinError 183, see above
|
||||
if sys.platform == "win32" and config_file.is_file():
|
||||
config_file.unlink()
|
||||
backup.rename(config_file)
|
||||
return
|
||||
|
||||
print(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
logger.info(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -518,8 +461,8 @@ def delete_weights(model_name: str, conf_stanza: dict):
|
||||
if re.match("/VAE/", conf_stanza.get("config")):
|
||||
return
|
||||
|
||||
print(
|
||||
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
||||
logger.warning(
|
||||
f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
||||
)
|
||||
|
||||
weights = Path(weights)
|
||||
@@ -528,4 +471,4 @@ def delete_weights(model_name: str, conf_stanza: dict):
|
||||
try:
|
||||
weights.unlink()
|
||||
except OSError as e:
|
||||
print(str(e))
|
||||
logger.error(str(e))
|
||||
|
||||
@@ -11,6 +11,7 @@ import gc
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
@@ -49,6 +50,10 @@ from ..stable_diffusion import (
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from ..install.model_install_backend import (
|
||||
Dataset_path,
|
||||
hf_download_with_resume,
|
||||
)
|
||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
@@ -1316,3 +1321,74 @@ class ModelManager(object):
|
||||
return (
|
||||
os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None
|
||||
)
|
||||
|
||||
def list_controlnet_models(self)->Dict[str,bool]:
|
||||
'''Return a dict of installed controlnet models; key is repo_id or short name
|
||||
of model (defined in INITIAL_MODELS), and valule is True if installed'''
|
||||
|
||||
cn_models = OmegaConf.load(Dataset_path).get('controlnet') or {}
|
||||
installed_models = {x: False for x in cn_models.keys()}
|
||||
|
||||
cn_dir = self.globals.controlnet_path
|
||||
installed_cn_models = dict()
|
||||
for root, dirs, files in os.walk(cn_dir):
|
||||
for name in dirs:
|
||||
if Path(root, name, '.download_complete').exists():
|
||||
installed_models.update({name.replace('--','/'): True})
|
||||
return installed_models
|
||||
|
||||
def install_controlnet_models(self, model_names: list[str], access_token: str=None):
|
||||
'''Download list of controlnet models; provide either repo_id or short name listed in INITIAL_MODELS.yaml'''
|
||||
short_names = OmegaConf.load(Dataset_path).get('controlnet') or {}
|
||||
dest_dir = self.globals.controlnet_path
|
||||
dest_dir.mkdir(parents=True,exist_ok=True)
|
||||
|
||||
# The model file may be fp32 or fp16, and may be either a
|
||||
# .bin file or a .safetensors. We try each until we get one,
|
||||
# preferring 'fp16' if using half precision, and preferring
|
||||
# safetensors over over bin.
|
||||
precisions = ['.fp16',''] if self.precision=='float16' else ['']
|
||||
formats = ['.safetensors','.bin']
|
||||
possible_filenames = list()
|
||||
for p in precisions:
|
||||
for f in formats:
|
||||
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
|
||||
|
||||
for directory_name in model_names:
|
||||
repo_id = short_names.get(directory_name) or directory_name
|
||||
safe_name = directory_name.replace('/','--')
|
||||
self.logger.info(f'Downloading ControlNet model {directory_name} ({repo_id})')
|
||||
hf_download_with_resume(
|
||||
repo_id = repo_id,
|
||||
model_dir = dest_dir / safe_name,
|
||||
model_name = 'config.json',
|
||||
access_token = access_token
|
||||
)
|
||||
|
||||
path = None
|
||||
for filename in possible_filenames:
|
||||
suffix = filename.suffix
|
||||
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
|
||||
self.logger.info(f'Checking availability of {directory_name}/{filename}...')
|
||||
path = hf_download_with_resume(
|
||||
repo_id = repo_id,
|
||||
model_dir = dest_dir / safe_name,
|
||||
model_name = str(filename),
|
||||
access_token = access_token,
|
||||
model_dest = Path(dest_dir, safe_name, dest_filename),
|
||||
)
|
||||
if path:
|
||||
(path.parent / '.download_complete').touch()
|
||||
break
|
||||
|
||||
def delete_controlnet_models(self, model_names: List[str]):
|
||||
'''Remove the list of controlnet models'''
|
||||
for name in model_names:
|
||||
safe_name = name.replace('/','--')
|
||||
directory = self.globals.controlnet_path / safe_name
|
||||
if directory.exists():
|
||||
self.logger.info(f'Purging controlnet model {name}')
|
||||
shutil.rmtree(str(directory))
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user