Merge branch 'diffusers_cross_attention_control_reimplementation' of github.com:damian0815/InvokeAI into diffusers_cross_attention_control_reimplementation
@@ -485,7 +485,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
         optimize_model(path[1], gen, opt, completer)
         completer.add_history(command)
         operation = None

     elif command.startswith('!optimize'):
         path = shlex.split(command)
@@ -570,27 +570,26 @@ def import_model(model_path:str, gen, opt, completer):
     (3) a huggingface repository id
     '''
     model_name = None

     if model_path.startswith(('http:','https:','ftp:')):
         model_name = import_ckpt_model(model_path, gen, opt, completer)
-    elif os.path.exists(model_path) and model_path.endswith('.ckpt') and os.path.isfile(model_path):
+    elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
         model_name = import_ckpt_model(model_path, gen, opt, completer)
     elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
         model_name = import_diffuser_model(model_path, gen, opt, completer)
     elif os.path.isdir(model_path):
-        model_name = import_diffuser_model(model_path, gen, opt, completer)
+        model_name = import_diffuser_model(Path(model_path), gen, opt, completer)
     else:
         print(f'** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can\'t import.')

     if not model_name:
         return

     if not _verify_load(model_name, gen):
         print('** model failed to load. Discarding configuration entry')
         gen.model_manager.del_model(model_name)
         return

-    if input('Make this the default model? [n] ') in ('y','Y'):
+    if input('Make this the default model? [n] ').strip() in ('y','Y'):
         gen.model_manager.set_default_model(model_name)

     gen.model_manager.commit(opt.conf)
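Note: the dispatcher above classifies the import source purely by the shape of the argument. A minimal standalone sketch of the same decision logic (function name hypothetical, not part of this commit):

    import os, re

    def classify_model_source(model_path: str) -> str:
        if model_path.startswith(('http:', 'https:', 'ftp:')):
            return 'url'                        # downloaded, then treated as a checkpoint
        if os.path.isfile(model_path) and model_path.endswith(('.ckpt', '.safetensors')):
            return 'checkpoint file'            # legacy .ckpt or the newly supported .safetensors
        if re.match(r'^[\w.+-]+/[\w.+-]+$', model_path):
            return 'huggingface repo id'        # e.g. 'stabilityai/stable-diffusion-2'
        if os.path.isdir(model_path):
            return 'diffusers directory'
        return 'unknown'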
@@ -607,10 +606,14 @@ def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str:
         model_name=default_name,
         model_description=default_description
     )
+    vae = None
+    if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ').strip() in ('y','Y'):
+        vae = dict(repo_id='stabilityai/sd-vae-ft-mse')

     if not manager.import_diffuser_model(
         path_or_repo,
         model_name = model_name,
+        vae = vae,
         description = model_description):
         print('** model failed to import')
         return None
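Note: answering "y" attaches a vae stanza to the model's registry entry. Based on the new_config layout used by import_diffuser_model later in this commit, the recorded configuration would plausibly be (model name and repo id illustrative):

    new_config = dict(
        description='imported diffusers model my-model',
        format='diffusers',
        repo_id='some-owner/some-model',                  # hypothetical
        vae=dict(repo_id='stabilityai/sd-vae-ft-mse'),
    )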
@@ -627,18 +630,29 @@ def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
         model_description=default_description
     )
     config_file = None
+    default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')

     completer.complete_extensions(('.yaml','.yml'))
-    completer.set_line('configs/stable-diffusion/v1-inference.yaml')
+    completer.set_line(str(default))
     done = False
     while not done:
         config_file = input('Configuration file for this model: ').strip()
         done = os.path.exists(config_file)

+    completer.complete_extensions(('.ckpt','.safetensors'))
+    vae = None
+    default = Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')
+    completer.set_line(str(default))
+    done = False
+    while not done:
+        vae = input('VAE file for this model (leave blank for none): ').strip() or None
+        done = (not vae) or os.path.exists(vae)
     completer.complete_extensions(None)

     if not manager.import_ckpt_model(
         path_or_url,
         config = config_file,
+        vae = vae,
         model_name = model_name,
         model_description = model_description,
         commit_to_conf = opt.conf,
@@ -690,7 +704,7 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
     else:
         print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file')
         return

     if not ckpt_path.is_absolute():
         ckpt_path = Path(Globals.root,ckpt_path)
@@ -698,7 +712,7 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
     if diffuser_path.exists():
         print(f'** {model_name_or_path} is already optimized. Will not overwrite. If this is an error, please remove the directory {diffuser_path} and try again.')
         return

     new_config = gen.model_manager.convert_and_import(
         ckpt_path,
         diffuser_path,
@@ -710,7 +724,7 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
         return

     completer.update_models(gen.model_manager.list_models())
-    if input(f'Load optimized model {model_name}? [y] ') not in ('n','N'):
+    if input(f'Load optimized model {model_name}? [y] ').strip() not in ('n','N'):
         gen.set_model(model_name)

     response = input(f'Delete the original .ckpt file at ({ckpt_path} ? [n] ')
@@ -726,7 +740,12 @@ def del_config(model_name:str, gen, opt, completer):
     if model_name not in gen.model_manager.config:
         print(f"** Unknown model {model_name}")
         return
-    gen.model_manager.del_model(model_name)
+
+    if input(f'Remove {model_name} from the list of models known to InvokeAI? [y] ').strip().startswith(('n','N')):
+        return
+
+    delete_completely = input('Completely remove the model file or directory from disk? [n] ').startswith(('y','Y'))
+    gen.model_manager.del_model(model_name,delete_files=delete_completely)
     gen.model_manager.commit(opt.conf)
     print(f'** {model_name} deleted')
     completer.update_models(gen.model_manager.list_models())
@@ -747,7 +766,7 @@ def edit_model(model_name:str, gen, opt, completer):
             continue
         completer.set_line(info[attribute])
         info[attribute] = input(f'{attribute}: ') or info[attribute]

     if new_name != model_name:
         manager.del_model(model_name)
@@ -1099,7 +1118,7 @@ def report_model_error(opt:Namespace, e:Exception):
     if yes_to_all is not None:
         sys.argv.append(yes_to_all)

-    import configure_invokeai
+    import ldm.invoke.configure_invokeai as configure_invokeai
     configure_invokeai.main()
     print('** InvokeAI will now restart')
     sys.argv = previous_args
@@ -56,9 +56,11 @@ class CkptGenerator():
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
                  attention_maps_callback = None,
                  free_gpu_mem: bool=False,
                  **kwargs):
         scope = choose_autocast(self.precision)
         self.safety_checker = safety_checker
         self.free_gpu_mem = free_gpu_mem
+        attention_maps_images = []
+        attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
         make_image = self.get_make_image(
@@ -21,7 +21,7 @@ import os
 import re
 import torch
 from pathlib import Path
-from ldm.invoke.globals import Globals
+from ldm.invoke.globals import Globals, global_cache_dir
 from safetensors.torch import load_file

 try:
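Note: the point of threading global_cache_dir('hub') into every from_pretrained() call below is to keep HuggingFace downloads inside the InvokeAI runtime directory instead of the default ~/.cache. A rough sketch of the assumed behavior of such a helper (not the function's actual source):

    import os
    from pathlib import Path

    def global_cache_dir(subfolder: str = 'hub') -> Path:
        # assumption: caches live under the runtime root, e.g. ~/invokeai/models/hub
        root = Path(os.environ.get('INVOKEAI_ROOT', Path.home() / 'invokeai'))
        return root / 'models' / subfolder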
@@ -637,7 +637,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):

 def convert_ldm_clip_checkpoint(checkpoint):
-    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14",cache_dir=global_cache_dir('hub'))

     keys = list(checkpoint.keys())
@@ -677,7 +677,8 @@ textenc_pattern = re.compile("|".join(protected.keys()))

 def convert_paint_by_example_checkpoint(checkpoint):
-    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
+    cache_dir = global_cache_dir('hub')
+    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
     model = PaintByExampleImageEncoder(config)

     keys = list(checkpoint.keys())
@@ -744,7 +745,8 @@ def convert_paint_by_example_checkpoint(checkpoint):

 def convert_open_clip_checkpoint(checkpoint):
-    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+    cache_dir=global_cache_dir('hub')
+    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir)

     keys = list(checkpoint.keys())
@@ -795,6 +797,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
                             ):

     checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
+    cache_dir = global_cache_dir('hub')

     # Sometimes models don't have the global_step item
     if "global_step" in checkpoint:
@@ -904,7 +907,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,

     if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer",cache_dir=global_cache_dir('diffusers'))
         pipe = StableDiffusionPipeline(
             vae=vae,
             text_encoder=text_model,
@@ -917,8 +920,8 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
         )
     elif model_type == "PaintByExample":
         vision_model = convert_paint_by_example_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
+        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
         pipe = PaintByExamplePipeline(
             vae=vae,
             image_encoder=vision_model,
@@ -929,9 +932,9 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
         )
     elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
         text_model = convert_ldm_clip_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
-        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
+        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
+        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
         pipe = StableDiffusionPipeline(
             vae=vae,
             text_encoder=text_model,
@@ -944,7 +947,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
     else:
         text_config = create_ldm_bert_config(original_config)
         text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
-        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
         pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

     pipe.save_pretrained(
@@ -59,7 +59,7 @@ class HuggingFaceConceptsLibrary(object):
         be downloaded.
         '''
         if not concept_name in self.list_concepts():
-            print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
+            print(f'This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept.')
             return None
         return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
@@ -115,13 +115,19 @@ class HuggingFaceConceptsLibrary(object):
             return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
         return self.match_trigger.sub(do_replace, prompt)

-    def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
+    def replace_concepts_with_triggers(self,
+                                       prompt:str,
+                                       load_concepts_callback: Callable[[list], any],
+                                       excluded_tokens:list[str])->str:
         '''
         Given a prompt string that contains `<concept_name>` tags, replace
         these tags with the appropriate trigger.

         If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
         of `concepts_name` strings.
+
+        `excluded_tokens` are any tokens that should not be replaced, typically because they
+        are trigger tokens from a locally-loaded embedding.
         '''
         concepts = self.match_concept.findall(prompt)
         if not concepts:
@@ -129,6 +135,8 @@ class HuggingFaceConceptsLibrary(object):
         load_concepts_callback(concepts)

         def do_replace(match)->str:
+            if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens:
+                return f'<{match.group(1)}>'
             return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
         return self.match_concept.sub(do_replace, prompt)
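Note: a hedged usage sketch of the extended signature (concept and trigger names invented for illustration):

    library = HuggingFaceConceptsLibrary()
    expanded = library.replace_concepts_with_triggers(
        'a portrait in the style of <birb-style>',
        load_concepts_callback=lambda concepts: print(f'downloading {concepts}'),
        excluded_tokens=['<my-local-trigger>'],   # locally-loaded embedding triggers stay untouched
    )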
ldm/invoke/configure_invokeai.py (new executable file, 876 lines)
@@ -0,0 +1,876 @@
#!/usr/bin/env python
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
# Before running stable-diffusion on an internet-isolated machine,
# run this script from one with internet connectivity. The
# two machines must share a common .cache directory.
#
# Coauthor: Kevin Turner http://github.com/keturn
#
print('Loading Python libraries...\n')
import argparse
import os
import io
import re
import shutil
import sys
import traceback
import warnings
from pathlib import Path
from typing import Dict, Union
from urllib import request

import requests
import transformers
from diffusers import StableDiffusionPipeline, AutoencoderKL
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.invoke.devices import choose_precision, choose_torch_device
from getpass_asterisk import getpass_asterisk
from huggingface_hub import HfFolder, hf_hub_url, login as hf_hub_login, whoami as hf_whoami
from huggingface_hub.utils._errors import RevisionNotFoundError
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModel

from ldm.invoke.globals import Globals, global_cache_dir
from ldm.invoke.readline import generic_completer

warnings.filterwarnings('ignore')
import torch
transformers.logging.set_verbosity_error()

try:
    from ldm.invoke.model_manager import ModelManager
except ImportError:
    sys.path.append('.')
    from ldm.invoke.model_manager import ModelManager

#--------------------------globals-----------------------
Model_dir = 'models'
Weights_dir = 'ldm/stable-diffusion-v1/'
Dataset_path = './configs/INITIAL_MODELS.yaml'
Default_config_file = './configs/models.yaml'
SD_Configs = './configs/stable-diffusion'

assert os.path.exists(Dataset_path), "The configs directory cannot be found. Please run this script from within the invokeai runtime directory."

Datasets = OmegaConf.load(Dataset_path)
completer = generic_completer(['yes','no'])

Config_preamble = '''# This file describes the alternative machine learning models
# available to InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
'''

#---------------------------------------------
def introduction():
    print(
        '''Welcome to InvokeAI. This script will help download the Stable Diffusion weight files
and other large models that are needed for text to image generation. At any point you may interrupt
this program and resume later.\n'''
    )
#--------------------------------------------
def postscript(errors: set = None):
    if not any(errors):
        message='''
** Model Installation Successful **

You're all set!

If you installed using one of the automated installation scripts,
execute 'invoke.sh' (Linux/macOS) or 'invoke.bat' (Windows) to
start InvokeAI.

If you installed manually, activate the 'invokeai' environment
(e.g. 'conda activate invokeai'), then run one of the following
commands to start InvokeAI.

Web UI:
   python scripts/invoke.py --web  # (connect to http://localhost:9090)
Command-line interface:
   python scripts/invoke.py

Have fun!
'''

    else:
        message=f"\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
        for err in errors:
            message += f"\t - {err}\n"
        message += "Please check the logs above and correct any issues."

    print(message)

#---------------------------------------------
def yes_or_no(prompt:str, default_yes=True):
    completer.set_options(['yes','no'])
    completer.complete_extensions(None)  # turn off path-completion mode
    default = "y" if default_yes else 'n'
    response = input(f'{prompt} [{default}] ') or default
    if default_yes:
        return response[0] not in ('n','N')
    else:
        return response[0] in ('y','Y')
#---------------------------------------------
def user_wants_to_download_weights()->str:
    '''
    Returns one of "skip", "recommended", "all" or "customized"
    '''
    print('''You can download and configure the weights files manually or let this
script do it for you. Manual installation is described at:

https://invoke-ai.github.io/InvokeAI/installation/020_INSTALL_MANUAL/

You may download the recommended models (about 10GB total), select a customized set, or
completely skip this step.
'''
    )
    completer.set_options(['recommended','customized','skip'])
    completer.complete_extensions(None)  # turn off path-completion mode
    selection = None
    while selection is None:
        choice = input('Download <r>ecommended models, <a>ll models, <c>ustomized list, or <s>kip this step? [r]: ')
        if choice.startswith(('r','R')) or len(choice)==0:
            selection = 'recommended'
        elif choice.startswith(('c','C')):
            selection = 'customized'
        elif choice.startswith(('a','A')):
            selection = 'all'
        elif choice.startswith(('s','S')):
            selection = 'skip'
    return selection
#---------------------------------------------
def select_datasets(action:str):
    done = False
    while not done:
        datasets = dict()
        dflt = None   # the first model selected will be the default; TODO let user change
        counter = 1

        if action == 'customized':
            print('''
Choose the weight file(s) you wish to download. Before downloading you
will be given the option to view and change your selections.
'''
            )
            for ds in Datasets.keys():
                recommended = Datasets[ds].get('recommended',False)
                r_str = '(recommended)' if recommended else ''
                print(f'[{counter}] {ds}:\n    {Datasets[ds]["description"]} {r_str}')
                if yes_or_no('    Download?',default_yes=recommended):
                    datasets[ds]=counter
                    counter += 1
        else:
            for ds in Datasets.keys():
                if Datasets[ds].get('recommended',False):
                    datasets[ds]=counter
                    counter += 1

        print('The following weight files will be downloaded:')
        for ds in datasets:
            dflt = '*' if dflt is None else ''
            print(f'   [{datasets[ds]}] {ds}{dflt}')
        print("*default")
        ok_to_download = yes_or_no('Ok to download?')
        if not ok_to_download:
            if yes_or_no('Change your selection?'):
                action = 'customized'
                pass
            else:
                done = True
        else:
            done = True
    return datasets if ok_to_download else None
#---------------------------------------------
def recommended_datasets()->dict:
    datasets = dict()
    for ds in Datasets.keys():
        if Datasets[ds].get('recommended',False):
            datasets[ds]=True
    return datasets

#---------------------------------------------
def default_dataset()->dict:
    datasets = dict()
    for ds in Datasets.keys():
        if Datasets[ds].get('default',False):
            datasets[ds]=True
    return datasets

#---------------------------------------------
def all_datasets()->dict:
    datasets = dict()
    for ds in Datasets.keys():
        datasets[ds]=True
    return datasets
#---------------------------------------------
def HfLogin(access_token) -> str:
    """
    Helper for logging in to Huggingface
    The stdout capture is needed to hide the irrelevant "git credential helper" warning
    """

    capture = io.StringIO()
    sys.stdout = capture
    try:
        hf_hub_login(token = access_token, add_to_git_credential=False)
        sys.stdout = sys.__stdout__
    except Exception as exc:
        sys.stdout = sys.__stdout__
        print(exc)
        raise exc
#-------------------------------Authenticate against Hugging Face
def authenticate(yes_to_all=False):
    print('** LICENSE AGREEMENT FOR WEIGHT FILES **')
    print("=" * shutil.get_terminal_size()[0])
    print('''
By downloading the Stable Diffusion weight files from the official Hugging Face
repository, you agree to have read and accepted the CreativeML Responsible AI License.
The license terms are located here:

   https://huggingface.co/spaces/CompVis/stable-diffusion-license

''')
    print("=" * shutil.get_terminal_size()[0])

    if not yes_to_all:
        accepted = False
        while not accepted:
            accepted = yes_or_no('Accept the above License terms?')
            if not accepted:
                print('Please accept the License or Ctrl+C to exit.')
            else:
                print('Thank you!')
    else:
        print("The program was started with a '--yes' flag, which indicates user's acceptance of the above License terms.")

    # Authenticate to Huggingface using environment variables.
    # If successful, authentication will persist for either interactive or non-interactive use.
    # Default env var expected by HuggingFace is HUGGING_FACE_HUB_TOKEN.
    print("=" * shutil.get_terminal_size()[0])
    print('Authenticating to Huggingface')
    hf_envvars = [ "HUGGING_FACE_HUB_TOKEN", "HUGGINGFACE_TOKEN" ]
    token_found = False
    if not (access_token := HfFolder.get_token()):
        print(f"Huggingface token not found in cache.")

        for ev in hf_envvars:
            if (access_token := os.getenv(ev)):
                print(f"Token was found in the {ev} environment variable.... Logging in.")
                try:
                    HfLogin(access_token)
                    continue
                except ValueError:
                    print(f"Login failed due to invalid token found in {ev}")
            else:
                print(f"Token was not found in the environment variable {ev}.")
    else:
        print(f"Huggingface token found in cache.")
        try:
            HfLogin(access_token)
            token_found = True
        except ValueError:
            print(f"Login failed due to invalid token found in cache")

    if not (yes_to_all or token_found):
        print(''' You may optionally enter your Huggingface token now. InvokeAI
*will* work without it but you will not be able to automatically
download some of the Hugging Face style concepts.  See
https://invoke-ai.github.io/InvokeAI/features/CONCEPTS/#using-a-hugging-face-concept
for more information.

Visit https://huggingface.co/settings/tokens to generate a token. (Sign up for an account if needed).

Paste the token below using Ctrl-V on macOS/Linux, or Ctrl-Shift-V or right-click on Windows.
Alternatively press 'Enter' to skip this step and continue.
You may re-run the configuration script again in the future if you do not wish to set the token right now.
''')
        again = True
        while again:
            try:
                access_token = getpass_asterisk.getpass_asterisk(prompt="HF Token ⮞ ")
                HfLogin(access_token)
                access_token = HfFolder.get_token()
                again = False
            except ValueError:
                again = yes_or_no('Failed to log in to Huggingface. Would you like to try again?')
                if not again:
                    print('\nRe-run the configuration script whenever you wish to set the token.')
                    print('...Continuing...')
            except EOFError:
                # this happens if the user pressed Enter on the prompt without any input; assume this means they don't want to input a token
                # safety net needed against accidental "Enter"?
                print("None provided - continuing")
                again = False

    elif access_token is None:
        print()
        print("HuggingFace login did not succeed. Some functionality may be limited; see https://invoke-ai.github.io/InvokeAI/features/CONCEPTS/#using-a-hugging-face-concept for more information")
        print()
        print(f"Re-run the configuration script without '--yes' to set the HuggingFace token interactively, or use one of the environment variables: {', '.join(hf_envvars)}")

    print("=" * shutil.get_terminal_size()[0])

    return access_token
#---------------------------------------------
# look for legacy model.ckpt in models directory and offer to
# normalize its name
def migrate_models_ckpt():
    model_path = os.path.join(Globals.root,Model_dir,Weights_dir)
    if not os.path.exists(os.path.join(model_path,'model.ckpt')):
        return
    new_name = Datasets['stable-diffusion-1.4']['file']
    print('You seem to have the Stable Diffusion v1.4 "model.ckpt" already installed.')
    rename = yes_or_no(f'Ok to rename it to "{new_name}" for future reference?')
    if rename:
        print(f'model.ckpt => {new_name}')
        os.replace(os.path.join(model_path,'model.ckpt'),os.path.join(model_path,new_name))
#---------------------------------------------
def download_weight_datasets(models:dict, access_token:str, precision:str='float32'):
    migrate_models_ckpt()
    successful = dict()
    for mod in models.keys():
        print(f'{mod}...',file=sys.stderr,end='')
        successful[mod] = _download_repo_or_file(Datasets[mod], access_token, precision=precision)
    return successful

def _download_repo_or_file(mconfig:DictConfig, access_token:str, precision:str='float32')->Path:
    path = None
    if mconfig['format'] == 'ckpt':
        path = _download_ckpt_weights(mconfig, access_token)
    else:
        path = _download_diffusion_weights(mconfig, access_token, precision=precision)
        if 'vae' in mconfig and 'repo_id' in mconfig['vae']:
            _download_diffusion_weights(mconfig['vae'], access_token, precision=precision)
    return path

def _download_ckpt_weights(mconfig:DictConfig, access_token:str)->Path:
    repo_id = mconfig['repo_id']
    filename = mconfig['file']
    cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
    return hf_download_with_resume(
        repo_id=repo_id,
        model_dir=cache_dir,
        model_name=filename,
        access_token=access_token
    )

def _download_diffusion_weights(mconfig:DictConfig, access_token:str, precision:str='float32'):
    repo_id = mconfig['repo_id']
    model_class = StableDiffusionGeneratorPipeline if mconfig.get('format',None)=='diffusers' else AutoencoderKL
    extra_arg_list = [{'revision':'fp16'},{}] if precision=='float16' else [{}]
    path = None
    for extra_args in extra_arg_list:
        try:
            path = download_from_hf(
                model_class,
                repo_id,
                cache_subdir='diffusers',
                safety_checker=None,
                **extra_args,
            )
        except OSError as e:
            if str(e).startswith('fp16 is not a valid'):
                print(f'Could not fetch half-precision version of model {repo_id}; fetching full-precision instead')
            else:
                print(f'An unexpected error occurred while downloading the model: {e}')
        if path:
            break
    return path
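Note: the extra_arg_list idiom above encodes a fallback policy: when float16 is requested, try the repo's fp16 revision first, then retry with no revision if the repo lacks one. The same control flow, spelled out (illustrative only):

    attempts = [{'revision': 'fp16'}, {}] if precision == 'float16' else [{}]
    path = None
    for extra in attempts:
        try:
            path = download_from_hf(model_class, repo_id, cache_subdir='diffusers', **extra)
        except OSError:
            continue        # e.g. "fp16 is not a valid git identifier"; retry without it
        if path:
            break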
#---------------------------------------------
def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->Path:
    model_dest = Path(os.path.join(model_dir, model_name))
    os.makedirs(model_dir, exist_ok=True)

    url = hf_hub_url(repo_id, model_name)

    header = {"Authorization": f'Bearer {access_token}'} if access_token else {}
    open_mode = 'wb'
    exist_size = 0

    if os.path.exists(model_dest):
        exist_size = os.path.getsize(model_dest)
        header['Range'] = f'bytes={exist_size}-'
        open_mode = 'ab'

    resp = requests.get(url, headers=header, stream=True)
    total = int(resp.headers.get('content-length', 0))

    if resp.status_code==416:  # "range not satisfiable", which means nothing to return
        print(f'* {model_name}: complete file found. Skipping.')
        return model_dest
    elif resp.status_code != 200:
        print(f'** An error occurred during downloading {model_name}: {resp.reason}')
    elif exist_size > 0:
        print(f'* {model_name}: partial file found. Resuming...')
    else:
        print(f'* {model_name}: Downloading...')

    try:
        if total < 2000:
            print(f'*** ERROR DOWNLOADING {model_name}: {resp.text}')
            return None

        with open(model_dest, open_mode) as file, tqdm(
                desc=model_name,
                initial=exist_size,
                total=total+exist_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1000,
        ) as bar:
            for data in resp.iter_content(chunk_size=1024):
                size = file.write(data)
                bar.update(size)
    except Exception as e:
        print(f'An error occurred while downloading {model_name}: {str(e)}')
        return None
    return model_dest
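Note: a hedged usage example (repo and file names are placeholders). The Range header makes the GET resumable; a 416 response means the local copy is already complete.

    dest = hf_download_with_resume(
        repo_id='runwayml/stable-diffusion-v1-5',
        model_dir='models/ldm/stable-diffusion-v1',
        model_name='v1-5-pruned-emaonly.ckpt',
        access_token=None,      # public repos need no token
    )
    if dest is None:
        print('download failed')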
#---------------------------------------------
def is_huggingface_authenticated():
    # huggingface_hub 0.10 API isn't great for this, it could be OSError, ValueError,
    # maybe other things, not all end-user-friendly.
    # noinspection PyBroadException
    try:
        response = hf_whoami()
        if response.get('id') is not None:
            return True
    except Exception:
        pass
    return False
#---------------------------------------------
def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
    try:
        print(f'Installing {label} model file {model_url}...',end='',file=sys.stderr)
        if not os.path.exists(model_dest):
            os.makedirs(os.path.dirname(model_dest), exist_ok=True)
            print('',file=sys.stderr)
            request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
            print('...downloaded successfully', file=sys.stderr)
        else:
            print('...exists', file=sys.stderr)
    except Exception:
        print('...download failed')
        print(f'Error downloading {label} model')
        print(traceback.format_exc())
#---------------------------------------------
def update_config_file(successfully_downloaded:dict,opt:dict):
    config_file = opt.config_file or Default_config_file
    config_file = os.path.normpath(os.path.join(Globals.root,config_file))

    yaml = new_config_file_contents(successfully_downloaded,config_file)

    try:
        if os.path.exists(config_file):
            print(f'** {config_file} exists. Renaming to {config_file}.orig')
            os.replace(config_file,f'{config_file}.orig')
        tmpfile = os.path.join(os.path.dirname(config_file),'new_config.tmp')
        with open(tmpfile, 'w') as outfile:
            outfile.write(Config_preamble)
            outfile.write(yaml)
        os.replace(tmpfile,config_file)

    except Exception as e:
        print(f'**Error creating config file {config_file}: {str(e)} **')
        return

    print(f'Successfully created new configuration file {config_file}')
#---------------------------------------------
def new_config_file_contents(successfully_downloaded:dict, config_file:str)->str:
    if os.path.exists(config_file):
        conf = OmegaConf.load(config_file)
    else:
        conf = OmegaConf.create()

    # find the VAE file, if there is one
    vaes = {}
    default_selected = False

    for model in successfully_downloaded:
        stanza = conf[model] if model in conf else { }
        mod = Datasets[model]
        stanza['description'] = mod['description']
        stanza['repo_id'] = mod['repo_id']
        stanza['format'] = mod['format']
        # diffusers don't need width and height (probably .ckpt doesn't either)
        # so we no longer require these in INITIAL_MODELS.yaml
        if 'width' in mod:
            stanza['width'] = mod['width']
        if 'height' in mod:
            stanza['height'] = mod['height']
        if 'file' in mod:
            stanza['weights'] = os.path.relpath(successfully_downloaded[model], start=Globals.root)
            stanza['config'] = os.path.normpath(os.path.join(SD_Configs,mod['config']))
        if 'vae' in mod:
            if 'file' in mod['vae']:
                stanza['vae'] = os.path.normpath(os.path.join(Model_dir, Weights_dir,mod['vae']['file']))
            else:
                stanza['vae'] = mod['vae']
        stanza.pop('default',None)  # this will be set later

        # BUG - the first stanza is always the default. User should select.
        if not default_selected:
            stanza['default'] = True
            default_selected = True
        conf[model] = stanza
    return OmegaConf.to_yaml(conf)
#---------------------------------------------
# this will preload the Bert tokenizer files
def download_bert():
    print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        from transformers import BertTokenizerFast
        download_from_hf(BertTokenizerFast,'bert-base-uncased')
        print('...success',file=sys.stderr)

#---------------------------------------------
def download_from_hf(model_class:object, model_name:str, cache_subdir:Path=Path('hub'), **kwargs):
    print('',file=sys.stderr)   # to prevent tqdm from overwriting
    path = global_cache_dir(cache_subdir)
    model = model_class.from_pretrained(model_name,
                                        cache_dir=path,
                                        resume_download=True,
                                        **kwargs,
                                        )
    return path if model else None
#---------------------------------------------
def download_clip():
    print('Installing CLIP model (ignore deprecation errors)...',file=sys.stderr)
    version = 'openai/clip-vit-large-patch14'
    print('Tokenizer...',file=sys.stderr, end='')
    download_from_hf(CLIPTokenizer,version)
    print('Text model...',file=sys.stderr, end='')
    download_from_hf(CLIPTextModel,version)
    print('...success',file=sys.stderr)

#---------------------------------------------
def download_realesrgan():
    print('Installing models from RealESRGAN...',file=sys.stderr)
    model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
    model_dest = os.path.join(Globals.root,'models/realesrgan/realesr-general-x4v3.pth')
    download_with_progress_bar(model_url, model_dest, 'RealESRGAN')

def download_gfpgan():
    print('Installing GFPGAN models...',file=sys.stderr)
    for model in (
            [
                'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
                './models/gfpgan/GFPGANv1.4.pth'
            ],
            [
                'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth',
                './models/gfpgan/weights/detection_Resnet50_Final.pth'
            ],
            [
                'https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth',
                './models/gfpgan/weights/parsing_parsenet.pth'
            ],
    ):
        model_url,model_dest = model[0],os.path.join(Globals.root,model[1])
        download_with_progress_bar(model_url, model_dest, 'GFPGAN weights')

#---------------------------------------------
def download_codeformer():
    print('Installing CodeFormer model file...',file=sys.stderr)
    model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
    model_dest = os.path.join(Globals.root,'models/codeformer/codeformer.pth')
    download_with_progress_bar(model_url, model_dest, 'CodeFormer')

#---------------------------------------------
def download_clipseg():
    print('Installing clipseg model for text-based masking...',end='', file=sys.stderr)
    import zipfile
    try:
        model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
        model_dest = os.path.join(Globals.root,'models/clipseg/clipseg_weights')
        weights_zip = 'models/clipseg/weights.zip'

        if not os.path.exists(model_dest):
            os.makedirs(os.path.dirname(model_dest), exist_ok=True)
        if not os.path.exists(f'{model_dest}/rd64-uni-refined.pth'):
            dest = os.path.join(Globals.root,weights_zip)
            request.urlretrieve(model_url,dest)
            with zipfile.ZipFile(dest,'r') as zip:
                zip.extractall(os.path.join(Globals.root,'models/clipseg'))
            os.remove(dest)

            from clipseg.clipseg import CLIPDensePredT
            model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
            model.eval()
            model.load_state_dict(
                torch.load(
                    os.path.join(Globals.root,'models/clipseg/clipseg_weights/rd64-uni-refined.pth'),
                    map_location=torch.device('cpu')
                ),
                strict=False,
            )
    except Exception:
        print('Error installing clipseg model:')
        print(traceback.format_exc())
    print('...success',file=sys.stderr)

#-------------------------------------
def download_safety_checker():
    print('Installing model for NSFW content detection...',file=sys.stderr)
    try:
        from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
        from transformers import AutoFeatureExtractor
    except ModuleNotFoundError:
        print('Error installing NSFW checker model:')
        print(traceback.format_exc())
        return
    safety_model_id = "CompVis/stable-diffusion-safety-checker"
    print('AutoFeatureExtractor...', end='',file=sys.stderr)
    download_from_hf(AutoFeatureExtractor,safety_model_id)
    print('StableDiffusionSafetyChecker...', end='',file=sys.stderr)
    download_from_hf(StableDiffusionSafetyChecker,safety_model_id)
    print('...success',file=sys.stderr)
#-------------------------------------
def download_weights(opt:dict) -> Union[str, None]:

    precision = 'float32' if opt.full_precision else choose_precision(torch.device(choose_torch_device()))

    if opt.yes_to_all:
        models = default_dataset() if opt.default_only else recommended_datasets()
        access_token = authenticate(opt.yes_to_all)
        if len(models)>0:
            successfully_downloaded = download_weight_datasets(models, access_token, precision=precision)
            update_config_file(successfully_downloaded,opt)
            return

    else:
        choice = user_wants_to_download_weights()

        if choice == 'recommended':
            models = recommended_datasets()
        elif choice == 'all':
            models = all_datasets()
        elif choice == 'customized':
            models = select_datasets(choice)
            if models is None and yes_or_no('Quit?',default_yes=False):
                sys.exit(0)
        else:  # 'skip'
            return

        access_token = authenticate()
        if access_token is not None:
            HfFolder.save_token(access_token)

        print('\n** DOWNLOADING WEIGHTS **')
        successfully_downloaded = download_weight_datasets(models, access_token, precision=precision)

        update_config_file(successfully_downloaded,opt)
        if len(successfully_downloaded) < len(models):
            return "some of the model weights downloads were not successful"
#-------------------------------------
def get_root(root:str=None)->str:
    if root:
        return root
    elif os.environ.get('INVOKEAI_ROOT'):
        return os.environ.get('INVOKEAI_ROOT')
    else:
        return Globals.root

#-------------------------------------
def select_root(root:str, yes_to_all:bool=False):
    default = root or os.path.expanduser('~/invokeai')
    if (yes_to_all):
        return default
    completer.set_default_dir(default)
    completer.complete_extensions(())
    completer.set_line(default)
    directory = input(f"Select a directory in which to install InvokeAI's models and configuration files [{default}]: ").strip(' \\')
    return directory or default

#-------------------------------------
def select_outputs(root:str,yes_to_all:bool=False):
    default = os.path.normpath(os.path.join(root,'outputs'))
    if (yes_to_all):
        return default
    completer.set_default_dir(os.path.expanduser('~'))
    completer.complete_extensions(())
    completer.set_line(default)
    directory = input(f'Select the default directory for image outputs [{default}]: ').strip(' \\')
    return directory or default
#-------------------------------------
def initialize_rootdir(root:str,yes_to_all:bool=False):
    assert os.path.exists('./configs'),'Run this script from within the InvokeAI source code directory, "InvokeAI" or the runtime directory "invokeai".'

    print(f'** INITIALIZING INVOKEAI RUNTIME DIRECTORY **')
    root_selected = False
    while not root_selected:
        outputs = select_outputs(root,yes_to_all)
        outputs = outputs if os.path.isabs(outputs) else os.path.abspath(os.path.join(Globals.root,outputs))

        print(f'\nInvokeAI image outputs will be placed into "{outputs}".')
        if not yes_to_all:
            root_selected = yes_or_no('Accept this location?')
        else:
            root_selected = True

    print(f'\nYou may change the chosen output directory at any time by editing the --outdir options in "{Globals.initfile}",')
    print(f'You may also change the runtime directory by setting the environment variable INVOKEAI_ROOT.\n')

    enable_safety_checker = True
    if not yes_to_all:
        print('The NSFW (not safe for work) checker blurs out images that potentially contain sexual imagery.')
        print('It can be selectively enabled at run time with --nsfw_checker, and disabled with --no-nsfw_checker.')
        print('The following option will set whether the checker is enabled by default. Like other options, you can')
        print(f'change this setting later by editing the file {Globals.initfile}.')
        print(f"This is NOT recommended for systems with less than 6G VRAM because of the checker's memory requirements.")
        enable_safety_checker = yes_or_no('Enable the NSFW checker by default?',enable_safety_checker)

    safety_checker = '--nsfw_checker' if enable_safety_checker else '--no-nsfw_checker'

    for name in ('models','configs','embeddings','text-inversion-data','text-inversion-training-data'):
        os.makedirs(os.path.join(root,name), exist_ok=True)
    for src in (['configs']):
        dest = os.path.join(root,src)
        if not os.path.samefile(src,dest):
            shutil.copytree(src,dest,dirs_exist_ok=True)
    os.makedirs(outputs, exist_ok=True)

    init_file = os.path.join(Globals.root,Globals.initfile)

    print(f'Creating the initialization file at "{init_file}".\n')
    with open(init_file,'w') as f:
        f.write(f'''# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values.
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
# or renaming it and then running configure_invokeai.py again.

# the --outdir option controls the default location of image files.
--outdir="{outputs}"

# generation arguments
{safety_checker}

# You may place other frequently-used startup commands here, one or more per line.
# Examples:
# --web --host=0.0.0.0
# --steps=20
# -Ak_euler_a -C10.0
#
''')
#-------------------------------------
class ProgressBar():
    def __init__(self,model_name='file'):
        self.pbar = None
        self.name = model_name

    def __call__(self, block_num, block_size, total_size):
        if not self.pbar:
            self.pbar=tqdm(desc=self.name,
                           initial=0,
                           unit='iB',
                           unit_scale=True,
                           unit_divisor=1000,
                           total=total_size)
        self.pbar.update(block_size)
#-------------------------------------
def main():
    parser = argparse.ArgumentParser(description='InvokeAI model downloader')
    parser.add_argument('--interactive',
                        dest='interactive',
                        action=argparse.BooleanOptionalAction,
                        default=True,
                        help='run in interactive mode (default) - DEPRECATED')
    parser.add_argument('--skip-sd-weights',
                        dest='skip_sd_weights',
                        action=argparse.BooleanOptionalAction,
                        default=False,
                        help='skip downloading the large Stable Diffusion weight files')
    parser.add_argument('--full-precision',
                        dest='full_precision',
                        action=argparse.BooleanOptionalAction,
                        type=bool,
                        default=False,
                        help='use 32-bit weights instead of faster 16-bit weights')
    parser.add_argument('--yes','-y',
                        dest='yes_to_all',
                        action='store_true',
                        help='answer "yes" to all prompts')
    parser.add_argument('--default_only',
                        action='store_true',
                        help='when --yes specified, only install the default model')
    parser.add_argument('--config_file',
                        '-c',
                        dest='config_file',
                        type=str,
                        default='./configs/models.yaml',
                        help='path to configuration file to create')
    parser.add_argument('--root_dir',
                        dest='root',
                        type=str,
                        default=None,
                        help='path to root of install directory')
    opt = parser.parse_args()

    # setting a global here
    Globals.root = os.path.expanduser(get_root(opt.root) or '')

    try:
        introduction()

        # We check to see if the runtime directory is correctly initialized.
        if Globals.root == '' \
           or not os.path.exists(os.path.join(Globals.root,'invokeai.init')):
            initialize_rootdir(Globals.root,opt.yes_to_all)

        # Optimistically try to download all required assets. If any errors occur, add them and proceed anyway.
        errors=set()

        if not opt.interactive:
            print("WARNING: The --(no)-interactive argument is deprecated and will be removed. Use --skip-sd-weights.")
            opt.skip_sd_weights=True
        if opt.skip_sd_weights:
            print('** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **')
        else:
            print('** DOWNLOADING DIFFUSION WEIGHTS **')
            errors.add(download_weights(opt))
        print('\n** DOWNLOADING SUPPORT MODELS **')
        download_bert()
        download_clip()
        download_realesrgan()
        download_gfpgan()
        download_codeformer()
        download_clipseg()
        download_safety_checker()
        postscript(errors=errors)
    except KeyboardInterrupt:
        print('\nGoodbye! Come back soon.')
    except Exception as e:
        print(f'\nA problem occurred during initialization.\nThe error was: "{str(e)}"')
        print(traceback.format_exc())

#-------------------------------------
if __name__ == '__main__':
    main()
@@ -62,9 +62,11 @@ class Generator:
     def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
                  free_gpu_mem: bool=False,
                  **kwargs):
         scope = nullcontext
         self.safety_checker = safety_checker
         self.free_gpu_mem = free_gpu_mem
+        attention_maps_images = []
+        attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
         make_image = self.get_make_image(
@@ -29,6 +29,7 @@ else:

 # Where to look for the initialization file
 Globals.initfile = 'invokeai.init'
+Globals.models_file = 'models.yaml'
 Globals.models_dir = 'models'
 Globals.config_dir = 'configs'
 Globals.autoscan_dir = 'weights'

@@ -49,6 +50,9 @@ Globals.disable_xformers = False
 # whether we are forcing full precision
 Globals.full_precision = False

+def global_config_file()->Path:
+    return Path(Globals.root, Globals.config_dir, Globals.models_file)
+
 def global_config_dir()->Path:
     return Path(Globals.root, Globals.config_dir)
ldm/invoke/merge_diffusers.py (new file, 62 lines)
@@ -0,0 +1,62 @@
'''
ldm.invoke.merge_diffusers exports a single function call merge_diffusion_models()
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
'''
import os
from typing import List
from diffusers import DiffusionPipeline
from ldm.invoke.globals import global_config_file, global_models_dir, global_cache_dir
from ldm.invoke.model_manager import ModelManager
from omegaconf import OmegaConf

def merge_diffusion_models(models:List['str'],
                           merged_model_name:str,
                           alpha:float=0.5,
                           interp:str=None,
                           force:bool=False,
                           **kwargs):
    '''
    models - up to three models, designated by their InvokeAI models.yaml model name
    merged_model_name - name for the new model
    alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
               would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
    interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
               Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
    force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.

    **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
         cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
    '''
    config_file = global_config_file()
    model_manager = ModelManager(OmegaConf.load(config_file))
    for mod in models:
        assert (mod in model_manager.model_names()), f'** Unknown model "{mod}"'
        assert (model_manager.model_info(mod).get('format',None) == 'diffusers'), f'** {mod} is not a diffusers model. It must be optimized before merging.'
    model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]

    pipe = DiffusionPipeline.from_pretrained(model_ids_or_paths[0],
                                             cache_dir=kwargs.get('cache_dir',global_cache_dir()),
                                             custom_pipeline='checkpoint_merger')
    merged_pipe = pipe.merge(pretrained_model_name_or_path_list=model_ids_or_paths,
                             alpha=alpha,
                             interp=interp,
                             force=force,
                             **kwargs)
    dump_path = global_models_dir() / 'merged_diffusers'
    os.makedirs(dump_path,exist_ok=True)
    dump_path = dump_path / merged_model_name
    merged_pipe.save_pretrained(
        dump_path,
        safe_serialization=1
    )
    model_manager.import_diffuser_model(
        dump_path,
        model_name = merged_model_name,
        description = f'Merge of models {", ".join(models)}'
    )
    print('REMINDER: When PR 2369 is merged, replace merge_diffusers.py line 56 with vae= argument to import_model()')
    if vae := model_manager.config[models[0]].get('vae',None):
        print(f'>> Using configured VAE assigned to {models[0]}')
        model_manager.config[merged_model_name]['vae'] = vae

    model_manager.commit(config_file)
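Note: a hedged usage sketch (model names are placeholders; both models must already be registered diffusers models in models.yaml):

    merge_diffusion_models(
        ['modelA', 'modelB'],
        merged_model_name='modelA-plus-B',
        alpha=0.3,       # result weighted 70% toward modelA
        interp=None,     # weighted-sum interpolation
    )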
@@ -18,7 +18,9 @@ import traceback
 import warnings
 import safetensors.torch
 from pathlib import Path
 from shutil import move, rmtree
 from typing import Union, Any
 from huggingface_hub import scan_cache_dir
 from ldm.util import download_with_progress_bar

 import torch
@@ -35,9 +37,16 @@ from ldm.invoke.globals import Globals, global_models_dir, global_autoscan_dir,
 from ldm.util import instantiate_from_config, ask_user

 DEFAULT_MAX_MODELS=2
+VAE_TO_REPO_ID = {  # hack, see note in convert_and_import()
+    'vae-ft-mse-840000-ema-pruned': 'stabilityai/sd-vae-ft-mse',
+    }

 class ModelManager(object):
-    def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
+    def __init__(self,
+                 config:OmegaConf,
+                 device_type:str='cpu',
+                 precision:str='float16',
+                 max_loaded_models=DEFAULT_MAX_MODELS):
         '''
         Initialize with the path to the models.yaml config file,
         the torch device type, and precision. The optional
@@ -143,7 +152,7 @@ class ModelManager(object):
         Return true if this is a legacy (.ckpt) model
         '''
         info = self.model_info(model_name)
-        if 'weights' in info and info['weights'].endswith('.ckpt'):
+        if 'weights' in info and info['weights'].endswith(('.ckpt','.safetensors')):
             return True
         return False
@@ -226,7 +235,7 @@ class ModelManager(object):
             line = f'\033[1m{line}\033[0m'
         print(line)

-    def del_model(self, model_name:str) -> None:
+    def del_model(self, model_name:str, delete_files:bool=False) -> None:
         '''
         Delete the named model.
         '''
@@ -234,9 +243,25 @@ class ModelManager(object):
|
||||
if model_name not in omega:
|
||||
print(f'** Unknown model {model_name}')
|
||||
return
|
||||
# save these for use in deletion later
|
||||
conf = omega[model_name]
|
||||
repo_id = conf.get('repo_id',None)
|
||||
path = self._abs_path(conf.get('path',None))
|
||||
weights = self._abs_path(conf.get('weights',None))
|
||||
|
||||
del omega[model_name]
|
||||
if model_name in self.stack:
|
||||
self.stack.remove(model_name)
|
||||
if delete_files:
|
||||
if weights:
|
||||
print(f'** deleting file {weights}')
|
||||
Path(weights).unlink(missing_ok=True)
|
||||
elif path:
|
||||
print(f'** deleting directory {path}')
|
||||
rmtree(path,ignore_errors=True)
|
||||
elif repo_id:
|
||||
print(f'** deleting the cached model directory for {repo_id}')
|
||||
self._delete_model_from_cache(repo_id)
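A usage sketch of the new delete_files flag; the model name and config path are illustrative:

# Hypothetical call; 'old-model' and the config path are illustrative.
manager.del_model('old-model', delete_files=True)  # drops the entry and its weights file, dir, or HF cache
manager.commit('configs/models.yaml')              # persist the edited models.yaml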

    def add_model(self, model_name:str, model_attributes:dict, clobber:bool=False) -> None:
        '''
@@ -362,8 +387,14 @@ class ModelManager(object):
            vae = os.path.normpath(os.path.join(Globals.root,vae))
            if os.path.exists(vae):
                print(f' | Loading VAE weights from: {vae}')
                vae_ckpt = torch.load(vae, map_location="cpu")
                vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
                vae_ckpt = None
                vae_dict = None
                if vae.endswith('.safetensors'):
                    vae_ckpt = safetensors.torch.load_file(vae)
                    vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
                else:
                    vae_ckpt = torch.load(vae, map_location="cpu")
                    vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != "loss"}
                model.first_stage_model.load_state_dict(vae_dict, strict=False)
            else:
                print(f' | VAE file {vae} not found. Skipping.')
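The branch above reduces to a small loader that handles both formats. A standalone sketch, assuming the same key-filtering rule; the helper name is hypothetical:

# Standalone sketch of the safetensors/ckpt branch above; load_vae_state_dict
# is a hypothetical helper, not part of this commit.
import torch
import safetensors.torch

def load_vae_state_dict(path: str) -> dict:
    if path.endswith('.safetensors'):
        ckpt = safetensors.torch.load_file(path)   # flat tensor dict, no 'state_dict' wrapper
        return {k: v for k, v in ckpt.items() if k[0:4] != 'loss'}
    ckpt = torch.load(path, map_location='cpu')    # legacy pickled checkpoint
    return {k: v for k, v in ckpt['state_dict'].items() if k[0:4] != 'loss'}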
@@ -407,7 +438,7 @@ class ModelManager(object):
            safety_checker=None,
            local_files_only=not Globals.internet_available
        )
        if 'vae' in mconfig:
        if 'vae' in mconfig and mconfig['vae'] is not None:
            vae = self._load_vae(mconfig['vae'])
            pipeline_args.update(vae=vae)
        if not isinstance(name_or_path,Path):
@@ -513,11 +544,12 @@ class ModelManager(object):
            print('>> Model scanned ok!')

    def import_diffuser_model(self,
                              repo_or_path:Union[str,Path],
                              model_name:str=None,
                              description:str=None,
                              commit_to_conf:Path=None,
                              )->bool:
                              repo_or_path:Union[str,Path],
                              model_name:str=None,
                              description:str=None,
                              vae:dict=None,
                              commit_to_conf:Path=None,
                              )->bool:
        '''
        Attempts to install the indicated diffuser model and returns True if successful.

@@ -533,10 +565,11 @@ class ModelManager(object):
        description = description or f'imported diffusers model {model_name}'
        new_config = dict(
            description=description,
            vae=vae,
            format='diffusers',
        )
        if isinstance(repo_or_path,Path) and repo_or_path.exists():
            new_config.update(path=repo_or_path)
            new_config.update(path=str(repo_or_path))
        else:
            new_config.update(repo_id=repo_or_path)

@@ -546,18 +579,22 @@ class ModelManager(object):
        return True
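A hedged usage sketch of the widened signature; the repo id and names are illustrative:

# Hypothetical call using the new vae= parameter; names are illustrative.
ok = manager.import_diffuser_model(
    'runwayml/stable-diffusion-v1-5',
    model_name='sd-1.5',
    description='Stable Diffusion 1.5',
    vae=dict(repo_id='stabilityai/sd-vae-ft-mse'),  # recorded under 'vae' in models.yaml
)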

    def import_ckpt_model(self,
                          weights:Union[str,Path],
                          config:Union[str,Path]='configs/stable-diffusion/v1-inference.yaml',
                          model_name:str=None,
                          model_description:str=None,
                          commit_to_conf:Path=None,
                          )->bool:
                          weights:Union[str,Path],
                          config:Union[str,Path]='configs/stable-diffusion/v1-inference.yaml',
                          vae:Union[str,Path]=None,
                          model_name:str=None,
                          model_description:str=None,
                          commit_to_conf:Path=None,
                          )->bool:
        '''
        Attempts to install the indicated ckpt file and returns True if successful.

        "weights" can be either a path-like object corresponding to a local .ckpt file
        or a http/https URL pointing to a remote model.

        "vae" is a Path or str object pointing to a ckpt or safetensors file to be used
        as the VAE for this model.

        "config" is the model config file to use with this ckpt file. It defaults to
        v1-inference.yaml. If a URL is provided, the config will be downloaded.

@@ -584,6 +621,8 @@ class ModelManager(object):
            width=512,
            height=512
        )
        if vae:
            new_config['vae'] = vae
        self.add_model(model_name, new_config, True)
        if commit_to_conf:
            self.commit(commit_to_conf)
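And the ckpt counterpart, following the docstring above; all paths are illustrative:

# Hypothetical call; paths are illustrative.
from pathlib import Path

ok = manager.import_ckpt_model(
    'models/ldm/stable-diffusion-v1/my-model.ckpt',
    config='configs/stable-diffusion/v1-inference.yaml',
    vae='models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt',
    model_name='my-model',
    commit_to_conf=Path('configs/models.yaml'),
)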
@@ -623,7 +662,7 @@ class ModelManager(object):

    def convert_and_import(self,
                           ckpt_path:Path,
                           diffuser_path:Path,
                           diffusers_path:Path,
                           model_name=None,
                           model_description=None,
                           commit_to_conf:Path=None,
@@ -635,46 +674,56 @@ class ModelManager(object):
        new_config = None
        from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
        import transformers
        if diffuser_path.exists():
            print(f'ERROR: The path {str(diffuser_path)} already exists. Please move or remove it and try again.')
        if diffusers_path.exists():
            print(f'ERROR: The path {str(diffusers_path)} already exists. Please move or remove it and try again.')
            return

        model_name = model_name or diffuser_path.name
        model_name = model_name or diffusers_path.name
        model_description = model_description or f'Optimized version of {model_name}'
        print(f'>> {model_name}: optimizing (30-60s).')
        print(f'>> Optimizing {model_name} (30-60s)')
        try:
            verbosity = transformers.logging.get_verbosity()
            transformers.logging.set_verbosity_error()
            convert_ckpt_to_diffuser(ckpt_path, diffuser_path, extract_ema=True)
            convert_ckpt_to_diffuser(ckpt_path, diffusers_path, extract_ema=True)
            transformers.logging.set_verbosity(verbosity)
            print(f'>> Success. Optimized model is now located at {str(diffuser_path)}')
            print(f'>> Writing new config file entry for {model_name}...',end='')
            print(f'>> Success. Optimized model is now located at {str(diffusers_path)}')
            print(f'>> Writing new config file entry for {model_name}')
            new_config = dict(
                path=str(diffuser_path),
                path=str(diffusers_path),
                description=model_description,
                format='diffusers',
            )

            # HACK (LS): in the event that the original entry is using a custom ckpt VAE, we try to
            # map that VAE onto a diffuser VAE using a hard-coded dictionary.
            # I would prefer to do this differently: We load the ckpt model into memory, swap the
            # VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped
            # VAE is built into the model. However, when I tried this I got obscure key errors.
            if model_name in self.config and (vae_ckpt_path := self.model_info(model_name)['vae']):
                vae_basename = Path(vae_ckpt_path).stem
                diffusers_vae = None
                if (diffusers_vae := VAE_TO_REPO_ID.get(vae_basename,None)):
                    print(f'>> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version')
                    new_config.update(
                        vae = {'repo_id': diffusers_vae}
                    )
                else:
                    print(f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown')
                    print(f'** Using "stabilityai/sd-vae-ft-mse"; if this isn\'t right, please edit the model config')
                    new_config.update(
                        vae = {'repo_id': 'stabilityai/sd-vae-ft-mse'}
                    )

            self.del_model(model_name)
            self.add_model(model_name, new_config, True)
            if commit_to_conf:
                self.commit(commit_to_conf)
            print('>> Conversion succeeded')
        except Exception as e:
            print(f'** Conversion failed: {str(e)}')
            traceback.print_exc()

        print('done.')
        return new_config
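The HACK described above amounts to a stem lookup with a documented fallback. A minimal sketch; the helper function is hypothetical, not part of this commit:

# Hypothetical helper distilling the stem-based lookup in the HACK above.
from pathlib import Path

def diffusers_vae_for(vae_ckpt_path: str) -> str:
    stem = Path(vae_ckpt_path).stem            # e.g. 'vae-ft-mse-840000-ema-pruned'
    return VAE_TO_REPO_ID.get(stem, 'stabilityai/sd-vae-ft-mse')  # documented fallback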

    def del_config(self, model_name:str, gen, opt, completer):
        current_model = gen.model_name
        if model_name == current_model:
            print("** Can't delete active model. !switch to another model first. **")
            return
        gen.model_manager.del_model(model_name)
        gen.model_manager.commit(opt.conf)
        print(f'** {model_name} deleted')
        completer.del_model(model_name)

    def search_models(self, search_folder):
        print(f'>> Finding Models In: {search_folder}')
        models_folder_ckpt = Path(search_folder).glob('**/*.ckpt')
@@ -756,7 +805,6 @@ class ModelManager(object):

        print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.')
        print('** This is a quick one-time operation.')
        from shutil import move, rmtree

        # transformer files get moved into the hub directory
        if cls._is_huggingface_hub_directory_present():
@@ -972,6 +1020,27 @@ class ModelManager(object):

        return vae

    @staticmethod
    def _delete_model_from_cache(repo_id):
        cache_info = scan_cache_dir(global_cache_dir('diffusers'))

        # I'm sure there is a way to do this with comprehensions
        # but the code quickly became incomprehensible!
        hashes_to_delete = set()
        for repo in cache_info.repos:
            if repo.repo_id==repo_id:
                for revision in repo.revisions:
                    hashes_to_delete.add(revision.commit_hash)
        strategy = cache_info.delete_revisions(*hashes_to_delete)
        print(f'** deletion of this model is expected to free {strategy.expected_freed_size_str}')
        strategy.execute()
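For what it's worth, the comprehension the comment gives up on is workable; an equivalent sketch:

# A set-comprehension equivalent of the loop above (same behavior).
hashes_to_delete = {
    revision.commit_hash
    for repo in cache_info.repos
    if repo.repo_id == repo_id
    for revision in repo.revisions
}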

    @staticmethod
    def _abs_path(path:Union[str,Path])->Path:
        if path is None or Path(path).is_absolute():
            return path
        return Path(Globals.root,path).resolve()

    @staticmethod
    def _is_huggingface_hub_directory_present() -> bool:
        return os.getenv('HF_HOME') is not None or os.getenv('XDG_CACHE_HOME') is not None

@@ -4,7 +4,6 @@
# and modified slightly by Lincoln Stein (@lstein) to work with InvokeAI

import argparse
from argparse import Namespace
import logging
import math
import os
@@ -207,6 +206,12 @@ def parse_args():
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=Path,
@@ -455,7 +460,8 @@ def do_textual_inversion_training(
    checkpointing_steps:int=500,
    resume_from_checkpoint:Path=None,
    enable_xformers_memory_efficient_attention:bool=False,
    root_dir:Path=None
    root_dir:Path=None,
    hub_model_id:str=None,
):
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != local_rank:
@@ -518,10 +524,10 @@ def do_textual_inversion_training(
    pretrained_model_name_or_path = model_conf.get('repo_id',None) or Path(model_conf.get('path'))
    assert pretrained_model_name_or_path, f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
    pipeline_args = dict(cache_dir=global_cache_dir('diffusers'))

    # Load tokenizer
    if tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name,cache_dir=global_cache_dir('transformers'))
        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name,**pipeline_args)
    else:
        tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", **pipeline_args)
@@ -631,7 +637,7 @@ def do_textual_inversion_training(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # For mixed precision training we cast the unet and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
@@ -670,6 +676,7 @@ def do_textual_inversion_training(
    logger.info(f"  Total optimization steps = {max_train_steps}")
    global_step = 0
    first_epoch = 0
    resume_step = None

    # Potentially load in the weights and states from a previous save
    if resume_from_checkpoint:
@@ -680,15 +687,22 @@ def do_textual_inversion_training(
        dirs = os.listdir(output_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1]
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(output_dir, path))
        global_step = int(path.split("-")[1])

        resume_global_step = global_step * gradient_accumulation_steps
        first_epoch = resume_global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % num_update_steps_per_epoch
        path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)
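A worked example of the resume arithmetic above; the numbers are illustrative:

# Illustrative numbers for the resume arithmetic above.
gradient_accumulation_steps = 4
num_update_steps_per_epoch = 100     # optimizer updates per epoch
global_step = 250                    # parsed from 'checkpoint-250'

resume_global_step = global_step * gradient_accumulation_steps  # 1000 raw batches seen so far
first_epoch = global_step // num_update_steps_per_epoch         # resume within epoch 2
resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)  # 1000 % 400 = batch 200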

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
@@ -700,7 +714,7 @@ def do_textual_inversion_training(
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            if resume_step and resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue