Merge branch 'main' into feat/import-with-vae

This commit is contained in:
Lincoln Stein
2023-01-24 09:07:41 -05:00
committed by GitHub
48 changed files with 2113 additions and 681 deletions

View File

@@ -146,7 +146,7 @@ class Generate:
gfpgan=None,
codeformer=None,
esrgan=None,
free_gpu_mem=False,
free_gpu_mem: bool=False,
safety_checker:bool=False,
max_loaded_models:int=2,
# these are deprecated; if present they override values in the conf file
@@ -460,10 +460,13 @@ class Generate:
init_image = None
mask_image = None
if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device:
self.model.cond_stage_model.device = self.model.device
self.model.cond_stage_model.to(self.model.device)
try:
if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device:
self.model.cond_stage_model.device = self.model.device
self.model.cond_stage_model.to(self.model.device)
except AttributeError:
print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.")
pass
try:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
@@ -531,6 +534,7 @@ class Generate:
inpaint_height = inpaint_height,
inpaint_width = inpaint_width,
enable_image_debugging = enable_image_debugging,
free_gpu_mem=self.free_gpu_mem,
)
if init_color:

View File

@@ -56,9 +56,11 @@ class CkptGenerator():
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
attention_maps_callback = None,
free_gpu_mem: bool=False,
**kwargs):
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker
self.free_gpu_mem = free_gpu_mem
attention_maps_images = []
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
make_image = self.get_make_image(

View File

@@ -62,9 +62,11 @@ class Generator:
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
free_gpu_mem: bool=False,
**kwargs):
scope = nullcontext
self.safety_checker = safety_checker
self.free_gpu_mem = free_gpu_mem
attention_maps_images = []
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
make_image = self.get_make_image(

View File

@@ -29,6 +29,7 @@ else:
# Where to look for the initialization file
Globals.initfile = 'invokeai.init'
Globals.models_file = 'models.yaml'
Globals.models_dir = 'models'
Globals.config_dir = 'configs'
Globals.autoscan_dir = 'weights'
@@ -49,6 +50,9 @@ Globals.disable_xformers = False
# whether we are forcing full precision
Globals.full_precision = False
def global_config_file()->Path:
return Path(Globals.root, Globals.config_dir, Globals.models_file)
def global_config_dir()->Path:
return Path(Globals.root, Globals.config_dir)

View File

@@ -0,0 +1,62 @@
'''
ldm.invoke.merge_diffusers exports a single function call merge_diffusion_models()
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
'''
import os
from typing import List
from diffusers import DiffusionPipeline
from ldm.invoke.globals import global_config_file, global_models_dir, global_cache_dir
from ldm.invoke.model_manager import ModelManager
from omegaconf import OmegaConf
def merge_diffusion_models(models:List['str'],
merged_model_name:str,
alpha:float=0.5,
interp:str=None,
force:bool=False,
**kwargs):
'''
models - up to three models, designated by their InvokeAI models.yaml model name
merged_model_name = name for new model
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
'''
config_file = global_config_file()
model_manager = ModelManager(OmegaConf.load(config_file))
for mod in models:
assert (mod in model_manager.model_names()), f'** Unknown model "{mod}"'
assert (model_manager.model_info(mod).get('format',None) == 'diffusers'), f'** {mod} is not a diffusers model. It must be optimized before merging.'
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
pipe = DiffusionPipeline.from_pretrained(model_ids_or_paths[0],
cache_dir=kwargs.get('cache_dir',global_cache_dir()),
custom_pipeline='checkpoint_merger')
merged_pipe = pipe.merge(pretrained_model_name_or_path_list=model_ids_or_paths,
alpha=alpha,
interp=interp,
force=force,
**kwargs)
dump_path = global_models_dir() / 'merged_diffusers'
os.makedirs(dump_path,exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained (
dump_path,
safe_serialization=1
)
model_manager.import_diffuser_model(
dump_path,
model_name = merged_model_name,
description = f'Merge of models {", ".join(models)}'
)
print('REMINDER: When PR 2369 is merged, replace merge_diffusers.py line 56 with vae= argument to impormodel()')
if vae := model_manager.config[models[0]].get('vae',None):
print(f'>> Using configured VAE assigned to {models[0]}')
model_manager.config[merged_model_name]['vae'] = vae
model_manager.commit(config_file)

View File

@@ -42,7 +42,11 @@ VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
}
class ModelManager(object):
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
def __init__(self,
config:OmegaConf,
device_type:str='cpu',
precision:str='float16',
max_loaded_models=DEFAULT_MAX_MODELS):
'''
Initialize with the path to the models.yaml config file,
the torch device type, and precision. The optional
@@ -565,7 +569,7 @@ class ModelManager(object):
format='diffusers',
)
if isinstance(repo_or_path,Path) and repo_or_path.exists():
new_config.update(path=repo_or_path)
new_config.update(path=str(repo_or_path))
else:
new_config.update(repo_id=repo_or_path)