mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Merge branch 'main' into api/add-trigger-string-retrieval
This commit is contained in:
@@ -5,7 +5,9 @@ import sys
|
||||
import traceback
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import click
|
||||
|
||||
if sys.platform == "darwin":
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
@@ -24,6 +26,7 @@ from ldm.invoke.model_manager import ModelManager
|
||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
from ldm.invoke.readline import Completer, get_completer
|
||||
from ldm.util import url_attachment_name
|
||||
|
||||
# global used in multiple functions (fix)
|
||||
infile = None
|
||||
@@ -78,7 +81,6 @@ def main():
|
||||
import transformers # type: ignore
|
||||
|
||||
from ldm.generate import Generate
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
import diffusers
|
||||
|
||||
@@ -623,10 +625,11 @@ def set_default_output_dir(opt: Args, completer: Completer):
|
||||
|
||||
|
||||
def import_model(model_path: str, gen, opt, completer):
|
||||
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
|
||||
(3) a huggingface repository id; or (4) a local directory containing a
|
||||
diffusers model.
|
||||
"""
|
||||
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or
|
||||
(3) a huggingface repository id
|
||||
"""
|
||||
model.path = model_path.replace('\\','/') # windows
|
||||
model_name = None
|
||||
|
||||
if model_path.startswith(("http:", "https:", "ftp:")):
|
||||
@@ -669,7 +672,7 @@ def import_model(model_path: str, gen, opt, completer):
|
||||
print("** model failed to load. Discarding configuration entry")
|
||||
gen.model_manager.del_model(model_name)
|
||||
return
|
||||
if input("Make this the default model? [n] ").strip() in ("y", "Y"):
|
||||
if click.confirm('Make this the default model?', default=False):
|
||||
gen.model_manager.set_default_model(model_name)
|
||||
|
||||
gen.model_manager.commit(opt.conf)
|
||||
@@ -677,9 +680,46 @@ def import_model(model_path: str, gen, opt, completer):
|
||||
print(f">> {model_name} successfully installed")
|
||||
|
||||
|
||||
def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
|
||||
'''
|
||||
Does a mass import of all the checkpoint/safetensors on a path list
|
||||
'''
|
||||
model_names = list()
|
||||
choice = input('** Directory of checkpoint/safetensors models detected. Install <a>ll or <s>elected models? [a] ') or 'a'
|
||||
do_all = choice.startswith('a')
|
||||
if do_all:
|
||||
config_file = _ask_for_config_file(models[0], completer, plural=True)
|
||||
manager = gen.model_manager
|
||||
for model in sorted(models):
|
||||
model_name = f'{model.stem}'
|
||||
model_description = f'Imported model {model_name}'
|
||||
if model_name in manager.model_names():
|
||||
print(f'** {model_name} is already imported. Skipping.')
|
||||
elif manager.import_ckpt_model(
|
||||
model,
|
||||
config = config_file,
|
||||
model_name = model_name,
|
||||
model_description = model_description,
|
||||
commit_to_conf = opt.conf):
|
||||
model_names.append(model_name)
|
||||
print(f'>> Model {model_name} imported successfully')
|
||||
else:
|
||||
print(f'** Model {model} failed to import')
|
||||
else:
|
||||
for model in sorted(models):
|
||||
if click.confirm(f'Import {model.stem} ?', default=True):
|
||||
if model_name := import_ckpt_model(model, gen, opt, completer):
|
||||
print(f'>> Model {model.stem} imported successfully')
|
||||
model_names.append(model_name)
|
||||
else:
|
||||
printf('** Model {model} failed to import')
|
||||
print()
|
||||
return model_names
|
||||
|
||||
def import_diffuser_model(
|
||||
path_or_repo: Union[Path, str], gen, _, completer
|
||||
) -> Optional[str]:
|
||||
path_or_repo = path_or_repo.replace('\\','/') # windows
|
||||
manager = gen.model_manager
|
||||
default_name = Path(path_or_repo).stem
|
||||
default_description = f"Imported model {default_name}"
|
||||
@@ -690,10 +730,8 @@ def import_diffuser_model(
|
||||
model_description=default_description,
|
||||
)
|
||||
vae = None
|
||||
if input(
|
||||
'Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] '
|
||||
).strip() in ("y", "Y"):
|
||||
vae = dict(repo_id="stabilityai/sd-vae-ft-mse")
|
||||
if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
|
||||
vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
|
||||
|
||||
if not manager.import_diffuser_model(
|
||||
path_or_repo, model_name=model_name, vae=vae, description=model_description
|
||||
@@ -702,13 +740,16 @@ def import_diffuser_model(
|
||||
return None
|
||||
return model_name
|
||||
|
||||
|
||||
def import_ckpt_model(
|
||||
path_or_url: Union[Path, str], gen, opt, completer
|
||||
) -> Optional[str]:
|
||||
path_or_url = path_or_url.replace('\\','/')
|
||||
manager = gen.model_manager
|
||||
default_name = Path(path_or_url).stem
|
||||
is_a_url = str(path_or_url).startswith(('http:','https:'))
|
||||
base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name
|
||||
default_name = Path(base_name).stem
|
||||
default_description = f"Imported model {default_name}"
|
||||
|
||||
model_name, model_description = _get_model_name_and_desc(
|
||||
manager,
|
||||
completer,
|
||||
@@ -758,10 +799,14 @@ def import_ckpt_model(
|
||||
def _verify_load(model_name: str, gen) -> bool:
|
||||
print(">> Verifying that new model loads...")
|
||||
current_model = gen.model_name
|
||||
if not gen.model_manager.get_model(model_name):
|
||||
try:
|
||||
if not gen.model_manager.get_model(model_name):
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f'** model failed to load: {str(e)}')
|
||||
print('** note that importing 2.X checkpoints is not supported. Please use !convert_model instead.')
|
||||
return False
|
||||
do_switch = input("Keep model loaded? [y] ")
|
||||
if len(do_switch) == 0 or do_switch[0] in ("y", "Y"):
|
||||
if click.confirm('Keep model loaded?', default=True):
|
||||
gen.set_model(model_name)
|
||||
else:
|
||||
print(">> Restoring previous model")
|
||||
@@ -780,18 +825,44 @@ def _get_model_name_and_desc(
|
||||
)
|
||||
return model_name, model_description
|
||||
|
||||
def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=False)->Path:
|
||||
default = '1'
|
||||
if re.search('inpaint',str(model_path),flags=re.IGNORECASE):
|
||||
default = '3'
|
||||
choices={
|
||||
'1': 'v1-inference.yaml',
|
||||
'2': 'v2-inference-v.yaml',
|
||||
'3': 'v1-inpainting-inference.yaml',
|
||||
}
|
||||
|
||||
prompt = '''What type of models are these?:
|
||||
[1] Models based on Stable Diffusion 1.X
|
||||
[2] Models based on Stable Diffusion 2.X
|
||||
[3] Inpainting models based on Stable Diffusion 1.X
|
||||
[4] Something else''' if plural else '''What type of model is this?:
|
||||
[1] A model based on Stable Diffusion 1.X
|
||||
[2] A model based on Stable Diffusion 2.X
|
||||
[3] An inpainting models based on Stable Diffusion 1.X
|
||||
[4] Something else'''
|
||||
print(prompt)
|
||||
choice = input(f'Your choice: [{default}] ')
|
||||
choice = choice.strip() or default
|
||||
if config_file := choices.get(choice,None):
|
||||
return Path('configs','stable-diffusion',config_file)
|
||||
|
||||
def _is_inpainting(model_name_or_path: str) -> bool:
|
||||
if re.search("inpaint", model_name_or_path, flags=re.IGNORECASE):
|
||||
return not input("Is this an inpainting model? [y] ").startswith(("n", "N"))
|
||||
else:
|
||||
return not input("Is this an inpainting model? [n] ").startswith(("y", "Y"))
|
||||
# otherwise ask user to select
|
||||
done = False
|
||||
completer.complete_extensions(('.yaml','.yml'))
|
||||
completer.set_line(str(Path(Globals.root,'configs/stable-diffusion/')))
|
||||
while not done:
|
||||
config_path = input('Configuration file for this model (leave blank to abort): ').strip()
|
||||
done = not config_path or os.path.exists(config_path)
|
||||
return config_path
|
||||
|
||||
|
||||
def optimize_model(model_name_or_path: str, gen, opt, completer):
|
||||
def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
|
||||
model_name_or_path = model_name_or_path.replace('\\','/') # windows
|
||||
manager = gen.model_manager
|
||||
ckpt_path = None
|
||||
original_config_file = None
|
||||
|
||||
if model_name_or_path == gen.model_name:
|
||||
print("** Can't convert the active model. !switch to another model first. **")
|
||||
@@ -806,16 +877,13 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
|
||||
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
|
||||
return
|
||||
elif os.path.exists(model_name_or_path):
|
||||
original_config_file = original_config_file or _ask_for_config_file(model_name_or_path, completer)
|
||||
if not original_config_file:
|
||||
return
|
||||
ckpt_path = Path(model_name_or_path)
|
||||
model_name, model_description = _get_model_name_and_desc(
|
||||
manager, completer, ckpt_path.stem, f"Converted model {ckpt_path.stem}"
|
||||
)
|
||||
is_inpainting = _is_inpainting(model_name_or_path)
|
||||
original_config_file = Path(
|
||||
"configs",
|
||||
"stable-diffusion",
|
||||
"v1-inpainting-inference.yaml" if is_inpainting else "v1-inference.yaml",
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"** {model_name_or_path} is neither an existing model nor the path to a .ckpt file"
|
||||
@@ -838,10 +906,8 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
|
||||
return
|
||||
|
||||
vae = None
|
||||
if input(
|
||||
'Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] '
|
||||
).strip() in ("y", "Y"):
|
||||
vae = dict(repo_id="stabilityai/sd-vae-ft-mse")
|
||||
if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
|
||||
vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
|
||||
|
||||
new_config = gen.model_manager.convert_and_import(
|
||||
ckpt_path,
|
||||
@@ -856,11 +922,10 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
|
||||
return
|
||||
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
if input(f"Load optimized model {model_name}? [y] ").strip() not in ("n", "N"):
|
||||
if click.confirm(f'Load optimized model {model_name}?', default=True):
|
||||
gen.set_model(model_name)
|
||||
|
||||
response = input(f"Delete the original .ckpt file at ({ckpt_path} ? [n] ")
|
||||
if response.startswith(("y", "Y")):
|
||||
if click.confirm(f'Delete the original .ckpt file at {ckpt_path}?',default=False):
|
||||
ckpt_path.unlink(missing_ok=True)
|
||||
print(f"{ckpt_path} deleted")
|
||||
|
||||
@@ -874,17 +939,11 @@ def del_config(model_name: str, gen, opt, completer):
|
||||
print(f"** Unknown model {model_name}")
|
||||
return
|
||||
|
||||
if (
|
||||
input(f"Remove {model_name} from the list of models known to InvokeAI? [y] ")
|
||||
.strip()
|
||||
.startswith(("n", "N"))
|
||||
):
|
||||
if not click.confirm(f'Remove {model_name} from the list of models known to InvokeAI?',default=True):
|
||||
return
|
||||
|
||||
delete_completely = input(
|
||||
"Completely remove the model file or directory from disk? [n] "
|
||||
).startswith(("y", "Y"))
|
||||
gen.model_manager.del_model(model_name, delete_files=delete_completely)
|
||||
delete_completely = click.confirm('Completely remove the model file or directory from disk?',default=False)
|
||||
gen.model_manager.del_model(model_name,delete_files=delete_completely)
|
||||
gen.model_manager.commit(opt.conf)
|
||||
print(f"** {model_name} deleted")
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
@@ -913,7 +972,7 @@ def edit_model(model_name: str, gen, opt, completer):
|
||||
# this does the update
|
||||
manager.add_model(new_name, info, True)
|
||||
|
||||
if input("Make this the default model? [n] ").startswith(("y", "Y")):
|
||||
if click.confirm('Make this the default model?',default=False):
|
||||
manager.set_default_model(new_name)
|
||||
manager.commit(opt.conf)
|
||||
completer.update_models(manager.list_models())
|
||||
@@ -1288,10 +1347,7 @@ def report_model_error(opt: Namespace, e: Exception):
|
||||
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
)
|
||||
else:
|
||||
response = input(
|
||||
"Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
|
||||
)
|
||||
if response.startswith(("n", "N")):
|
||||
if click.confirm('Do you want to run invokeai-configure script to select and/or reinstall models?', default=True):
|
||||
return
|
||||
|
||||
print("invokeai-configure is launching....\n")
|
||||
|
||||
@@ -34,8 +34,8 @@ from ldm.invoke.generator.diffusers_pipeline import \
|
||||
StableDiffusionGeneratorPipeline
|
||||
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
|
||||
global_models_dir)
|
||||
from ldm.util import (ask_user, download_with_progress_bar,
|
||||
instantiate_from_config)
|
||||
from ldm.util import (ask_user, download_with_resume,
|
||||
url_attachment_name, instantiate_from_config)
|
||||
|
||||
DEFAULT_MAX_MODELS = 2
|
||||
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
|
||||
@@ -673,15 +673,18 @@ class ModelManager(object):
|
||||
path to the configuration file, then the new entry will be committed to the
|
||||
models.yaml file.
|
||||
"""
|
||||
if str(weights).startswith(("http:", "https:")):
|
||||
model_name = model_name or url_attachment_name(weights)
|
||||
|
||||
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
|
||||
config_path = self._resolve_path(config, "configs/stable-diffusion")
|
||||
config_path = self._resolve_path(config, "configs/stable-diffusion")
|
||||
|
||||
if weights_path is None or not weights_path.exists():
|
||||
return False
|
||||
if config_path is None or not config_path.exists():
|
||||
return False
|
||||
|
||||
model_name = model_name or Path(weights).stem
|
||||
model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header
|
||||
model_description = (
|
||||
model_description or f"imported stable diffusion weights file {model_name}"
|
||||
)
|
||||
@@ -971,16 +974,15 @@ class ModelManager(object):
|
||||
print("** Migration is done. Continuing...")
|
||||
|
||||
def _resolve_path(
|
||||
self, source: Union[str, Path], dest_directory: str
|
||||
self, source: Union[str, Path], dest_directory: str
|
||||
) -> Optional[Path]:
|
||||
resolved_path = None
|
||||
if str(source).startswith(("http:", "https:", "ftp:")):
|
||||
basename = os.path.basename(source)
|
||||
if not os.path.isabs(dest_directory):
|
||||
dest_directory = os.path.join(Globals.root, dest_directory)
|
||||
dest = os.path.join(dest_directory, basename)
|
||||
if download_with_progress_bar(str(source), Path(dest)):
|
||||
resolved_path = Path(dest)
|
||||
dest_directory = Path(dest_directory)
|
||||
if not dest_directory.is_absolute():
|
||||
dest_directory = Globals.root / dest_directory
|
||||
dest_directory.mkdir(parents=True, exist_ok=True)
|
||||
resolved_path = download_with_resume(str(source), dest_directory)
|
||||
else:
|
||||
if not os.path.isabs(source):
|
||||
source = os.path.join(Globals.root, source)
|
||||
|
||||
Reference in New Issue
Block a user