detect lora files with .pt suffix

This commit is contained in:
Lincoln Stein
2023-04-01 17:25:54 -04:00
parent 605ceb2e95
commit d3b63ca0fe

View File

@@ -25,11 +25,20 @@ from invokeai.backend.modules.parameters import parameters_to_command
import invokeai.frontend.dist as frontend
from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, split_weighted_subprompts, \
get_tokenizer
from ldm.invoke.conditioning import (
get_tokens_for_prompt_object,
get_prompt_structure,
split_weighted_subprompts,
get_tokenizer,
)
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
from ldm.invoke.generator.inpaint import infill_methods
from ldm.invoke.globals import Globals, global_converted_ckpts_dir, global_models_dir, global_lora_models_dir
from ldm.invoke.globals import (
Globals,
global_converted_ckpts_dir,
global_models_dir,
global_lora_models_dir,
)
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
from compel.prompt_parser import Blend
from ldm.invoke.merge_diffusers import merge_diffusion_models
@@ -192,8 +201,7 @@ class InvokeAIWebServer:
(width, height) = pil_image.size
thumbnail_path = save_thumbnail(
pil_image, os.path.basename(
file_path), self.thumbnail_image_path
pil_image, os.path.basename(file_path), self.thumbnail_image_path
)
response = {
@@ -223,7 +231,7 @@ class InvokeAIWebServer:
server="flask_socketio",
width=1600,
height=1000,
port=self.port
port=self.port,
).run()
except KeyboardInterrupt:
import sys
@@ -264,16 +272,14 @@ class InvokeAIWebServer:
# location for "finished" images
self.result_path = args.outdir
# temporary path for intermediates
self.intermediate_path = os.path.join(
self.result_path, "intermediates/")
self.intermediate_path = os.path.join(self.result_path, "intermediates/")
# path for user-uploaded init images and masks
self.init_image_path = os.path.join(self.result_path, "init-images/")
self.mask_image_path = os.path.join(self.result_path, "mask-images/")
# path for temp images e.g. gallery generations which are not committed
self.temp_image_path = os.path.join(self.result_path, "temp-images/")
# path for thumbnail images
self.thumbnail_image_path = os.path.join(
self.result_path, "thumbnails/")
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
# txt log
self.log_path = os.path.join(self.result_path, "invoke_log.txt")
# make all output paths
@@ -298,21 +304,22 @@ class InvokeAIWebServer:
config["infill_methods"] = infill_methods()
socketio.emit("systemConfig", config)
@socketio.on('searchForModels')
@socketio.on("searchForModels")
def handle_search_models(search_folder: str):
try:
if not search_folder:
socketio.emit(
"foundModels",
{'search_folder': None, 'found_models': None},
{"search_folder": None, "found_models": None},
)
else:
search_folder, found_models = self.generate.model_manager.search_models(
search_folder)
(
search_folder,
found_models,
) = self.generate.model_manager.search_models(search_folder)
socketio.emit(
"foundModels",
{'search_folder': search_folder,
'found_models': found_models},
{"search_folder": search_folder, "found_models": found_models},
)
except Exception as e:
self.handle_exceptions(e)
@@ -321,11 +328,11 @@ class InvokeAIWebServer:
@socketio.on("addNewModel")
def handle_add_model(new_model_config: dict):
try:
model_name = new_model_config['name']
del new_model_config['name']
model_name = new_model_config["name"]
del new_model_config["name"]
model_attributes = new_model_config
if len(model_attributes['vae']) == 0:
del model_attributes['vae']
if len(model_attributes["vae"]) == 0:
del model_attributes["vae"]
update = False
current_model_list = self.generate.model_manager.list_models()
if model_name in current_model_list:
@@ -334,14 +341,20 @@ class InvokeAIWebServer:
print(f">> Adding New Model: {model_name}")
self.generate.model_manager.add_model(
model_name=model_name, model_attributes=model_attributes, clobber=True)
model_name=model_name,
model_attributes=model_attributes,
clobber=True,
)
self.generate.model_manager.commit(opt.conf)
new_model_list = self.generate.model_manager.list_models()
socketio.emit(
"newModelAdded",
{"new_model_name": model_name,
"model_list": new_model_list, 'update': update},
{
"new_model_name": model_name,
"model_list": new_model_list,
"update": update,
},
)
print(f">> New Model Added: {model_name}")
except Exception as e:
@@ -356,8 +369,10 @@ class InvokeAIWebServer:
updated_model_list = self.generate.model_manager.list_models()
socketio.emit(
"modelDeleted",
{"deleted_model_name": model_name,
"model_list": updated_model_list},
{
"deleted_model_name": model_name,
"model_list": updated_model_list,
},
)
print(f">> Model Deleted: {model_name}")
except Exception as e:
@@ -382,41 +397,48 @@ class InvokeAIWebServer:
except Exception as e:
self.handle_exceptions(e)
@socketio.on('convertToDiffusers')
@socketio.on("convertToDiffusers")
def convert_to_diffusers(model_to_convert: dict):
try:
if (model_info := self.generate.model_manager.model_info(model_name=model_to_convert['model_name'])):
if 'weights' in model_info:
ckpt_path = Path(model_info['weights'])
original_config_file = Path(model_info['config'])
model_name = model_to_convert['model_name']
model_description = model_info['description']
if model_info := self.generate.model_manager.model_info(
model_name=model_to_convert["model_name"]
):
if "weights" in model_info:
ckpt_path = Path(model_info["weights"])
original_config_file = Path(model_info["config"])
model_name = model_to_convert["model_name"]
model_description = model_info["description"]
else:
self.socketio.emit(
"error", {"message": "Model is not a valid checkpoint file"})
"error", {"message": "Model is not a valid checkpoint file"}
)
else:
self.socketio.emit(
"error", {"message": "Could not retrieve model info."})
"error", {"message": "Could not retrieve model info."}
)
if not ckpt_path.is_absolute():
ckpt_path = Path(Globals.root, ckpt_path)
if original_config_file and not original_config_file.is_absolute():
original_config_file = Path(
Globals.root, original_config_file)
original_config_file = Path(Globals.root, original_config_file)
diffusers_path = Path(
ckpt_path.parent.absolute(),
f'{model_name}_diffusers'
ckpt_path.parent.absolute(), f"{model_name}_diffusers"
)
if model_to_convert['save_location'] == 'root':
if model_to_convert["save_location"] == "root":
diffusers_path = Path(
global_converted_ckpts_dir(), f'{model_name}_diffusers')
global_converted_ckpts_dir(), f"{model_name}_diffusers"
)
if model_to_convert['save_location'] == 'custom' and model_to_convert['custom_location'] is not None:
if (
model_to_convert["save_location"] == "custom"
and model_to_convert["custom_location"] is not None
):
diffusers_path = Path(
model_to_convert['custom_location'], f'{model_name}_diffusers')
model_to_convert["custom_location"], f"{model_name}_diffusers"
)
if diffusers_path.exists():
shutil.rmtree(diffusers_path)
@@ -434,75 +456,89 @@ class InvokeAIWebServer:
new_model_list = self.generate.model_manager.list_models()
socketio.emit(
"modelConverted",
{"new_model_name": model_name,
"model_list": new_model_list, 'update': True},
{
"new_model_name": model_name,
"model_list": new_model_list,
"update": True,
},
)
print(f">> Model Converted: {model_name}")
except Exception as e:
self.handle_exceptions(e)
@socketio.on('mergeDiffusersModels')
@socketio.on("mergeDiffusersModels")
def merge_diffusers_models(model_merge_info: dict):
try:
models_to_merge = model_merge_info['models_to_merge']
models_to_merge = model_merge_info["models_to_merge"]
model_ids_or_paths = [
self.generate.model_manager.model_name_or_path(x) for x in models_to_merge]
self.generate.model_manager.model_name_or_path(x)
for x in models_to_merge
]
merged_pipe = merge_diffusion_models(
model_ids_or_paths, model_merge_info['alpha'], model_merge_info['interp'], model_merge_info['force'])
model_ids_or_paths,
model_merge_info["alpha"],
model_merge_info["interp"],
model_merge_info["force"],
)
dump_path = global_models_dir() / 'merged_models'
if model_merge_info['model_merge_save_path'] is not None:
dump_path = Path(model_merge_info['model_merge_save_path'])
dump_path = global_models_dir() / "merged_models"
if model_merge_info["model_merge_save_path"] is not None:
dump_path = Path(model_merge_info["model_merge_save_path"])
os.makedirs(dump_path, exist_ok=True)
dump_path = dump_path / model_merge_info['merged_model_name']
dump_path = dump_path / model_merge_info["merged_model_name"]
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
merged_model_config = dict(
model_name=model_merge_info['merged_model_name'],
model_name=model_merge_info["merged_model_name"],
description=f'Merge of models {", ".join(models_to_merge)}',
commit_to_conf=opt.conf
commit_to_conf=opt.conf,
)
if vae := self.generate.model_manager.config[models_to_merge[0]].get("vae", None):
print(
f">> Using configured VAE assigned to {models_to_merge[0]}")
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
"vae", None
):
print(f">> Using configured VAE assigned to {models_to_merge[0]}")
merged_model_config.update(vae=vae)
self.generate.model_manager.import_diffuser_model(
dump_path, **merged_model_config)
dump_path, **merged_model_config
)
new_model_list = self.generate.model_manager.list_models()
socketio.emit(
"modelsMerged",
{"merged_models": models_to_merge,
"merged_model_name": model_merge_info['merged_model_name'],
"model_list": new_model_list, 'update': True},
{
"merged_models": models_to_merge,
"merged_model_name": model_merge_info["merged_model_name"],
"model_list": new_model_list,
"update": True,
},
)
print(f">> Models Merged: {models_to_merge}")
print(
f">> New Model Added: {model_merge_info['merged_model_name']}")
print(f">> New Model Added: {model_merge_info['merged_model_name']}")
except Exception as e:
self.handle_exceptions(e)
@socketio.on('getLoraModels')
@socketio.on("getLoraModels")
def get_lora_models():
try:
lora_path = global_lora_models_dir()
lora_folder_ckpt = Path(lora_path).glob("**/*.ckpt")
lora_folder_safetensors = Path(lora_path).glob("**/*.safetensors")
ckpt_loras = [x for x in lora_folder_ckpt if x.is_file()]
safetensors_loras = [x for x in lora_folder_safetensors if x.is_file()]
loras = ckpt_loras + safetensors_loras
loras = []
for root, _, files in os.walk(lora_path):
models = [
Path(root, x)
for x in files
if Path(x).suffix in [".ckpt", ".pt", ".safetensors"]
]
loras = loras + models
found_loras = []
for lora in loras:
location = str(lora.resolve()).replace("\\", "/")
found_loras.append({"name": lora.stem, "location": location})
socketio.emit('foundLoras', found_loras)
socketio.emit("foundLoras", found_loras)
except Exception as e:
self.handle_exceptions(e)
@@ -520,7 +556,8 @@ class InvokeAIWebServer:
os.remove(thumbnail_path)
except Exception as e:
socketio.emit(
"error", {"message": f"Unable to delete {f}: {str(e)}"})
"error", {"message": f"Unable to delete {f}: {str(e)}"}
)
pass
socketio.emit("tempFolderEmptied")
@@ -531,8 +568,7 @@ class InvokeAIWebServer:
def save_temp_image_to_gallery(url):
try:
image_path = self.get_image_path_from_url(url)
new_path = os.path.join(
self.result_path, os.path.basename(image_path))
new_path = os.path.join(self.result_path, os.path.basename(image_path))
shutil.copy2(image_path, new_path)
if os.path.splitext(new_path)[1] == ".png":
@@ -545,8 +581,7 @@ class InvokeAIWebServer:
(width, height) = pil_image.size
thumbnail_path = save_thumbnail(
pil_image, os.path.basename(
new_path), self.thumbnail_image_path
pil_image, os.path.basename(new_path), self.thumbnail_image_path
)
image_array = [
@@ -605,8 +640,7 @@ class InvokeAIWebServer:
(width, height) = pil_image.size
thumbnail_path = save_thumbnail(
pil_image, os.path.basename(
path), self.thumbnail_image_path
pil_image, os.path.basename(path), self.thumbnail_image_path
)
image_array.append(
@@ -625,7 +659,8 @@ class InvokeAIWebServer:
)
except Exception as e:
socketio.emit(
"error", {"message": f"Unable to load {path}: {str(e)}"})
"error", {"message": f"Unable to load {path}: {str(e)}"}
)
pass
socketio.emit(
@@ -675,8 +710,7 @@ class InvokeAIWebServer:
(width, height) = pil_image.size
thumbnail_path = save_thumbnail(
pil_image, os.path.basename(
path), self.thumbnail_image_path
pil_image, os.path.basename(path), self.thumbnail_image_path
)
image_array.append(
@@ -696,7 +730,8 @@ class InvokeAIWebServer:
except Exception as e:
print(f">> Unable to load {path}")
socketio.emit(
"error", {"message": f"Unable to load {path}: {str(e)}"})
"error", {"message": f"Unable to load {path}: {str(e)}"}
)
pass
socketio.emit(
@@ -730,10 +765,9 @@ class InvokeAIWebServer:
printable_parameters["init_mask"][:64] + "..."
)
print(
f'\n>> Image Generation Parameters:\n\n{printable_parameters}\n')
print(f'>> ESRGAN Parameters: {esrgan_parameters}')
print(f'>> Facetool Parameters: {facetool_parameters}')
print(f"\n>> Image Generation Parameters:\n\n{printable_parameters}\n")
print(f">> ESRGAN Parameters: {esrgan_parameters}")
print(f">> Facetool Parameters: {facetool_parameters}")
self.generate_images(
generation_parameters,
@@ -770,11 +804,9 @@ class InvokeAIWebServer:
if postprocessing_parameters["type"] == "esrgan":
progress.set_current_status("common.statusUpscalingESRGAN")
elif postprocessing_parameters["type"] == "gfpgan":
progress.set_current_status(
"common.statusRestoringFacesGFPGAN")
progress.set_current_status("common.statusRestoringFacesGFPGAN")
elif postprocessing_parameters["type"] == "codeformer":
progress.set_current_status(
"common.statusRestoringFacesCodeFormer")
progress.set_current_status("common.statusRestoringFacesCodeFormer")
socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
@@ -939,8 +971,7 @@ class InvokeAIWebServer:
init_img_url = generation_parameters["init_img"]
original_bounding_box = generation_parameters["bounding_box"].copy(
)
original_bounding_box = generation_parameters["bounding_box"].copy()
initial_image = dataURL_to_image(
generation_parameters["init_img"]
@@ -1017,8 +1048,9 @@ class InvokeAIWebServer:
elif generation_parameters["generation_mode"] == "img2img":
init_img_url = generation_parameters["init_img"]
init_img_path = self.get_image_path_from_url(init_img_url)
generation_parameters["init_img"] = Image.open(
init_img_path).convert('RGB')
generation_parameters["init_img"] = Image.open(init_img_path).convert(
"RGB"
)
def image_progress(sample, step):
if self.canceled.is_set():
@@ -1078,8 +1110,7 @@ class InvokeAIWebServer:
)
if generation_parameters["progress_latents"]:
image = self.generate.sample_to_lowres_estimated_image(
sample)
image = self.generate.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
@@ -1098,8 +1129,7 @@ class InvokeAIWebServer:
},
)
self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
def image_done(image, seed, first_seed, attention_maps_image=None):
@@ -1126,8 +1156,7 @@ class InvokeAIWebServer:
progress.set_current_status("common.statusGenerationComplete")
self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
all_parameters = generation_parameters
@@ -1138,8 +1167,7 @@ class InvokeAIWebServer:
and all_parameters["variation_amount"] > 0
):
first_seed = first_seed or seed
this_variation = [
[seed, all_parameters["variation_amount"]]]
this_variation = [[seed, all_parameters["variation_amount"]]]
all_parameters["with_variations"] = (
prior_variations + this_variation
)
@@ -1155,14 +1183,13 @@ class InvokeAIWebServer:
if esrgan_parameters:
progress.set_current_status("common.statusUpscaling")
progress.set_current_status_has_steps(False)
self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
image = self.esrgan.process(
image=image,
upsampler_scale=esrgan_parameters["level"],
denoise_str=esrgan_parameters['denoise_str'],
denoise_str=esrgan_parameters["denoise_str"],
strength=esrgan_parameters["strength"],
seed=seed,
)
@@ -1170,7 +1197,7 @@ class InvokeAIWebServer:
postprocessing = True
all_parameters["upscale"] = [
esrgan_parameters["level"],
esrgan_parameters['denoise_str'],
esrgan_parameters["denoise_str"],
esrgan_parameters["strength"],
]
@@ -1179,15 +1206,14 @@ class InvokeAIWebServer:
if facetool_parameters:
if facetool_parameters["type"] == "gfpgan":
progress.set_current_status(
"common.statusRestoringFacesGFPGAN")
progress.set_current_status("common.statusRestoringFacesGFPGAN")
elif facetool_parameters["type"] == "codeformer":
progress.set_current_status(
"common.statusRestoringFacesCodeFormer")
"common.statusRestoringFacesCodeFormer"
)
progress.set_current_status_has_steps(False)
self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
if facetool_parameters["type"] == "gfpgan":
@@ -1217,8 +1243,7 @@ class InvokeAIWebServer:
all_parameters["facetool_type"] = facetool_parameters["type"]
progress.set_current_status("common.statusSavingImage")
self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
# restore the stashed URLS and discard the paths, we are about to send the result to client
@@ -1235,8 +1260,7 @@ class InvokeAIWebServer:
if generation_parameters["generation_mode"] == "unifiedCanvas":
all_parameters["bounding_box"] = original_bounding_box
metadata = self.parameters_to_generated_image_metadata(
all_parameters)
metadata = self.parameters_to_generated_image_metadata(all_parameters)
command = parameters_to_command(all_parameters)
@@ -1266,22 +1290,27 @@ class InvokeAIWebServer:
if progress.total_iterations > progress.current_iteration:
progress.set_current_step(1)
progress.set_current_status(
"common.statusIterationComplete")
progress.set_current_status("common.statusIterationComplete")
progress.set_current_status_has_steps(False)
else:
progress.mark_complete()
self.socketio.emit(
"progressUpdate", progress.to_formatted_dict())
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
parsed_prompt, _ = get_prompt_structure(
generation_parameters["prompt"])
tokens = None if type(parsed_prompt) is Blend else \
get_tokens_for_prompt_object(get_tokenizer(self.generate.model), parsed_prompt)
attention_maps_image_base64_url = None if attention_maps_image is None \
parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
tokens = (
None
if type(parsed_prompt) is Blend
else get_tokens_for_prompt_object(
get_tokenizer(self.generate.model), parsed_prompt
)
)
attention_maps_image_base64_url = (
None
if attention_maps_image is None
else image_to_dataURL(attention_maps_image)
)
self.socketio.emit(
"generationResult",
@@ -1313,7 +1342,7 @@ class InvokeAIWebServer:
self.generate.prompt2image(
**generation_parameters,
step_callback=diffusers_step_callback_adapter,
image_callback=image_done
image_callback=image_done,
)
except KeyboardInterrupt:
@@ -1436,8 +1465,7 @@ class InvokeAIWebServer:
self, parameters, original_image_path
):
try:
current_metadata = retrieve_metadata(
original_image_path)["sd-metadata"]
current_metadata = retrieve_metadata(original_image_path)["sd-metadata"]
postprocessing_metadata = {}
"""
@@ -1477,8 +1505,7 @@ class InvokeAIWebServer:
postprocessing_metadata
)
else:
current_metadata["image"]["postprocessing"] = [
postprocessing_metadata]
current_metadata["image"]["postprocessing"] = [postprocessing_metadata]
return current_metadata
@@ -1574,8 +1601,7 @@ class InvokeAIWebServer:
)
elif "thumbnails" in url:
return os.path.abspath(
os.path.join(self.thumbnail_image_path,
os.path.basename(url))
os.path.join(self.thumbnail_image_path, os.path.basename(url))
)
else:
return os.path.abspath(
@@ -1621,7 +1647,7 @@ class InvokeAIWebServer:
except Exception as e:
self.handle_exceptions(e)
def handle_exceptions(self, exception, emit_key: str = 'error'):
def handle_exceptions(self, exception, emit_key: str = "error"):
self.socketio.emit(emit_key, {"message": (str(exception))})
print("\n")
traceback.print_exc()