Merge branch 'main' into api/add-trigger-string-retrieval

Lincoln Stein
2023-02-17 15:53:57 -05:00
committed by GitHub
56 changed files with 1592 additions and 1112 deletions

@@ -5,7 +5,9 @@ import sys
import traceback
from argparse import Namespace
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union
import click
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -24,6 +26,7 @@ from ldm.invoke.model_manager import ModelManager
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.prompt_parser import PromptParser
from ldm.invoke.readline import Completer, get_completer
from ldm.util import url_attachment_name
# global used in multiple functions (fix)
infile = None
@@ -78,7 +81,6 @@ def main():
import transformers # type: ignore
from ldm.generate import Generate
transformers.logging.set_verbosity_error()
import diffusers
@@ -623,10 +625,11 @@ def set_default_output_dir(opt: Args, completer: Completer):
def import_model(model_path: str, gen, opt, completer):
"""
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
(3) a huggingface repository id; or (4) a local directory containing a
diffusers model.
"""
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or
(3) a huggingface repository id
"""
model_path = model_path.replace('\\','/') # windows
model_name = None
if model_path.startswith(("http:", "https:", "ftp:")):
@@ -669,7 +672,7 @@ def import_model(model_path: str, gen, opt, completer):
print("** model failed to load. Discarding configuration entry")
gen.model_manager.del_model(model_name)
return
if input("Make this the default model? [n] ").strip() in ("y", "Y"):
if click.confirm('Make this the default model?', default=False):
gen.model_manager.set_default_model(model_name)
gen.model_manager.commit(opt.conf)
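
The change repeated throughout this file's hunks is the replacement of hand-rolled input() prompts with click.confirm. A minimal side-by-side sketch of the two styles (the function names here are illustrative, not from the codebase):

import click

# Before: hand-rolled prompt; only a literal "y"/"Y" counts as yes, and
# every call site fakes the default by appending "[n] " to the text.
def confirm_legacy(prompt: str) -> bool:
    return input(f"{prompt} [n] ").strip() in ("y", "Y")

# After: click renders the [y/N] suffix from `default`, validates the
# answer, and re-prompts on unrecognized input.
def confirm_click(prompt: str) -> bool:
    return click.confirm(prompt, default=False)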
@@ -677,9 +680,46 @@ def import_model(model_path: str, gen, opt, completer):
print(f">> {model_name} successfully installed")
def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
'''
Does a mass import of all the checkpoint/safetensors on a path list
'''
model_names = list()
choice = input('** Directory of checkpoint/safetensors models detected. Install <a>ll or <s>elected models? [a] ') or 'a'
do_all = choice.startswith('a')
if do_all:
config_file = _ask_for_config_file(models[0], completer, plural=True)
manager = gen.model_manager
for model in sorted(models):
model_name = f'{model.stem}'
model_description = f'Imported model {model_name}'
if model_name in manager.model_names():
print(f'** {model_name} is already imported. Skipping.')
elif manager.import_ckpt_model(
model,
config = config_file,
model_name = model_name,
model_description = model_description,
commit_to_conf = opt.conf):
model_names.append(model_name)
print(f'>> Model {model_name} imported successfully')
else:
print(f'** Model {model} failed to import')
else:
for model in sorted(models):
if click.confirm(f'Import {model.stem} ?', default=True):
if model_name := import_ckpt_model(model, gen, opt, completer):
print(f'>> Model {model.stem} imported successfully')
model_names.append(model_name)
else:
print(f'** Model {model} failed to import')
print()
return model_names
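
For context, a hypothetical sketch of how a caller might gather the path list handed to import_checkpoint_list(). Only the .ckpt/.safetensors extensions come from the diff above; the scanning helper itself is an assumption:

from pathlib import Path
from typing import List

def scan_for_checkpoints(directory: Path) -> List[Path]:
    # Collect checkpoint/safetensors files directly under `directory`,
    # sorted so imports happen in a stable order.
    return sorted(
        p for p in directory.iterdir()
        if p.is_file() and p.suffix in (".ckpt", ".safetensors")
    )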
def import_diffuser_model(
path_or_repo: Union[Path, str], gen, _, completer
) -> Optional[str]:
path_or_repo = path_or_repo.replace('\\','/') # windows
manager = gen.model_manager
default_name = Path(path_or_repo).stem
default_description = f"Imported model {default_name}"
@@ -690,10 +730,8 @@ def import_diffuser_model(
model_description=default_description,
)
vae = None
if input(
'Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] '
).strip() in ("y", "Y"):
vae = dict(repo_id="stabilityai/sd-vae-ft-mse")
if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
if not manager.import_diffuser_model(
path_or_repo, model_name=model_name, vae=vae, description=model_description
@@ -702,13 +740,16 @@ def import_diffuser_model(
return None
return model_name
def import_ckpt_model(
path_or_url: Union[Path, str], gen, opt, completer
) -> Optional[str]:
path_or_url = path_or_url.replace('\\','/')
manager = gen.model_manager
default_name = Path(path_or_url).stem
is_a_url = str(path_or_url).startswith(('http:','https:'))
base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name
default_name = Path(base_name).stem
default_description = f"Imported model {default_name}"
model_name, model_description = _get_model_name_and_desc(
manager,
completer,
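
The lines added above change where the default model name comes from when importing by URL. A self-contained restatement of that logic, using the url_attachment_name import added at the top of this file:

from pathlib import Path
from ldm.util import url_attachment_name  # import added in this diff

def default_model_name(path_or_url: str) -> str:
    # For a URL, the filename comes from the server (its Content-Disposition
    # header) via url_attachment_name(); for a local path, from the path itself.
    is_a_url = path_or_url.startswith(("http:", "https:"))
    base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name
    # .stem drops the .ckpt / .safetensors suffix.
    return Path(base_name).stem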
@@ -758,10 +799,14 @@ def import_ckpt_model(
def _verify_load(model_name: str, gen) -> bool:
print(">> Verifying that new model loads...")
current_model = gen.model_name
if not gen.model_manager.get_model(model_name):
try:
if not gen.model_manager.get_model(model_name):
return False
except Exception as e:
print(f'** model failed to load: {str(e)}')
print('** note that importing 2.X checkpoints is not supported. Please use !convert_model instead.')
return False
do_switch = input("Keep model loaded? [y] ")
if len(do_switch) == 0 or do_switch[0] in ("y", "Y"):
if click.confirm('Keep model loaded?', default=True):
gen.set_model(model_name)
else:
print(">> Restoring previous model")
@@ -780,18 +825,44 @@ def _get_model_name_and_desc(
)
return model_name, model_description
def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=False)->Path:
default = '1'
if re.search('inpaint',str(model_path),flags=re.IGNORECASE):
default = '3'
choices={
'1': 'v1-inference.yaml',
'2': 'v2-inference-v.yaml',
'3': 'v1-inpainting-inference.yaml',
}
prompt = '''What type of models are these?:
[1] Models based on Stable Diffusion 1.X
[2] Models based on Stable Diffusion 2.X
[3] Inpainting models based on Stable Diffusion 1.X
[4] Something else''' if plural else '''What type of model is this?:
[1] A model based on Stable Diffusion 1.X
[2] A model based on Stable Diffusion 2.X
[3] An inpainting model based on Stable Diffusion 1.X
[4] Something else'''
print(prompt)
choice = input(f'Your choice: [{default}] ')
choice = choice.strip() or default
if config_file := choices.get(choice,None):
return Path('configs','stable-diffusion',config_file)
def _is_inpainting(model_name_or_path: str) -> bool:
if re.search("inpaint", model_name_or_path, flags=re.IGNORECASE):
return not input("Is this an inpainting model? [y] ").startswith(("n", "N"))
else:
return not input("Is this an inpainting model? [n] ").startswith(("y", "Y"))
# otherwise ask user to select
done = False
completer.complete_extensions(('.yaml','.yml'))
completer.set_line(str(Path(Globals.root,'configs/stable-diffusion/')))
while not done:
config_path = input('Configuration file for this model (leave blank to abort): ').strip()
done = not config_path or os.path.exists(config_path)
return config_path
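
A condensed, non-interactive sketch of the selection logic in _ask_for_config_file. The choices mapping and the inpainting default come from the diff; the wrapper function is illustrative:

import re
from pathlib import Path
from typing import Optional

CHOICES = {
    "1": "v1-inference.yaml",
    "2": "v2-inference-v.yaml",
    "3": "v1-inpainting-inference.yaml",
}

def config_for_choice(model_path: str, choice: str = "") -> Optional[Path]:
    # Checkpoints whose path mentions "inpaint" default to the inpainting config.
    default = "3" if re.search("inpaint", model_path, flags=re.IGNORECASE) else "1"
    name = CHOICES.get(choice.strip() or default)
    # An unknown answer (choice "4", "something else") falls through to None,
    # at which point the real function asks for a config path interactively.
    return Path("configs", "stable-diffusion", name) if name else None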
def optimize_model(model_name_or_path: str, gen, opt, completer):
def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
model_name_or_path = model_name_or_path.replace('\\','/') # windows
manager = gen.model_manager
ckpt_path = None
original_config_file = None
if model_name_or_path == gen.model_name:
print("** Can't convert the active model. !switch to another model first. **")
@@ -806,16 +877,13 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
return
elif os.path.exists(model_name_or_path):
original_config_file = original_config_file or _ask_for_config_file(model_name_or_path, completer)
if not original_config_file:
return
ckpt_path = Path(model_name_or_path)
model_name, model_description = _get_model_name_and_desc(
manager, completer, ckpt_path.stem, f"Converted model {ckpt_path.stem}"
)
is_inpainting = _is_inpainting(model_name_or_path)
original_config_file = Path(
"configs",
"stable-diffusion",
"v1-inpainting-inference.yaml" if is_inpainting else "v1-inference.yaml",
)
else:
print(
f"** {model_name_or_path} is neither an existing model nor the path to a .ckpt file"
@@ -838,10 +906,8 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
return
vae = None
if input(
'Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] '
).strip() in ("y", "Y"):
vae = dict(repo_id="stabilityai/sd-vae-ft-mse")
if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
new_config = gen.model_manager.convert_and_import(
ckpt_path,
@@ -856,11 +922,10 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
return
completer.update_models(gen.model_manager.list_models())
if input(f"Load optimized model {model_name}? [y] ").strip() not in ("n", "N"):
if click.confirm(f'Load optimized model {model_name}?', default=True):
gen.set_model(model_name)
response = input(f"Delete the original .ckpt file at ({ckpt_path} ? [n] ")
if response.startswith(("y", "Y")):
if click.confirm(f'Delete the original .ckpt file at {ckpt_path}?',default=False):
ckpt_path.unlink(missing_ok=True)
print(f"{ckpt_path} deleted")
@@ -874,17 +939,11 @@ def del_config(model_name: str, gen, opt, completer):
print(f"** Unknown model {model_name}")
return
if (
input(f"Remove {model_name} from the list of models known to InvokeAI? [y] ")
.strip()
.startswith(("n", "N"))
):
if not click.confirm(f'Remove {model_name} from the list of models known to InvokeAI?',default=True):
return
delete_completely = input(
"Completely remove the model file or directory from disk? [n] "
).startswith(("y", "Y"))
gen.model_manager.del_model(model_name, delete_files=delete_completely)
delete_completely = click.confirm('Completely remove the model file or directory from disk?',default=False)
gen.model_manager.del_model(model_name,delete_files=delete_completely)
gen.model_manager.commit(opt.conf)
print(f"** {model_name} deleted")
completer.update_models(gen.model_manager.list_models())
@@ -913,7 +972,7 @@ def edit_model(model_name: str, gen, opt, completer):
# this does the update
manager.add_model(new_name, info, True)
if input("Make this the default model? [n] ").startswith(("y", "Y")):
if click.confirm('Make this the default model?',default=False):
manager.set_default_model(new_name)
manager.commit(opt.conf)
completer.update_models(manager.list_models())
@@ -1288,10 +1347,7 @@ def report_model_error(opt: Namespace, e: Exception):
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else:
response = input(
"Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
)
if response.startswith(("n", "N")):
if click.confirm('Do you want to run invokeai-configure script to select and/or reinstall models?', default=True):
return
print("invokeai-configure is launching....\n")

@@ -34,8 +34,8 @@ from ldm.invoke.generator.diffusers_pipeline import \
StableDiffusionGeneratorPipeline
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
global_models_dir)
from ldm.util import (ask_user, download_with_progress_bar,
instantiate_from_config)
from ldm.util import (ask_user, download_with_resume,
url_attachment_name, instantiate_from_config)
DEFAULT_MAX_MODELS = 2
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
@@ -673,15 +673,18 @@ class ModelManager(object):
path to the configuration file, then the new entry will be committed to the
models.yaml file.
"""
if str(weights).startswith(("http:", "https:")):
model_name = model_name or url_attachment_name(weights)
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
config_path = self._resolve_path(config, "configs/stable-diffusion")
if weights_path is None or not weights_path.exists():
return False
if config_path is None or not config_path.exists():
return False
model_name = model_name or Path(weights).stem
model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header
model_description = (
model_description or f"imported stable diffusion weights file {model_name}"
)
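
url_attachment_name itself is not shown in this diff; the new code and the comment above only imply that it derives a filename from the server's Content-Disposition header. A hedged guess at such a helper (this is an assumption about ldm.util, not its actual code):

import requests
from email.message import Message

def attachment_name(url: str) -> str:
    # Ask the server for headers only; follow redirects to the final target.
    response = requests.head(url, allow_redirects=True)
    # email.message.Message knows how to parse the filename parameter out of
    # a Content-Disposition header value.
    msg = Message()
    msg["Content-Disposition"] = response.headers.get("Content-Disposition", "")
    # Fall back to the last URL segment (the "ugly pathname" case noted above).
    return msg.get_filename() or url.rsplit("/", 1)[-1]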
@@ -971,16 +974,15 @@ class ModelManager(object):
print("** Migration is done. Continuing...")
def _resolve_path(
self, source: Union[str, Path], dest_directory: str
) -> Optional[Path]:
resolved_path = None
if str(source).startswith(("http:", "https:", "ftp:")):
basename = os.path.basename(source)
if not os.path.isabs(dest_directory):
dest_directory = os.path.join(Globals.root, dest_directory)
dest = os.path.join(dest_directory, basename)
if download_with_progress_bar(str(source), Path(dest)):
resolved_path = Path(dest)
dest_directory = Path(dest_directory)
if not dest_directory.is_absolute():
dest_directory = Globals.root / dest_directory
dest_directory.mkdir(parents=True, exist_ok=True)
resolved_path = download_with_resume(str(source), dest_directory)
else:
if not os.path.isabs(source):
source = os.path.join(Globals.root, source)
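
The rewritten URL branch above delegates destination naming to download_with_resume, which the new code relies on to return the final Path (the old download_with_progress_bar returned only a success flag, and the caller composed the destination from os.path.basename itself). A sketch under that assumption, with `root` standing in for Globals.root:

from pathlib import Path
from typing import Optional, Union
from ldm.util import download_with_resume  # import added in this diff

def resolve_remote(source: str, dest_directory: Union[str, Path], root: Path) -> Optional[Path]:
    dest_directory = Path(dest_directory)
    if not dest_directory.is_absolute():
        dest_directory = root / dest_directory
    dest_directory.mkdir(parents=True, exist_ok=True)
    # Assumed contract: returns the downloaded file's Path, or None on failure,
    # letting the server (not the URL) decide the filename.
    return download_with_resume(source, dest_directory)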