From 8b8cc7fd33b4caf5f214035d4132f245be489b18 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Tue, 21 Mar 2023 19:22:20 +0530 Subject: [PATCH] [SD] Update LoRA inference to handle various checkpoints (#1215) --- apps/stable_diffusion/scripts/img2img.py | 22 +--- apps/stable_diffusion/scripts/inpaint.py | 22 +--- apps/stable_diffusion/scripts/outpaint.py | 20 +-- apps/stable_diffusion/scripts/txt2img.py | 23 +--- apps/stable_diffusion/scripts/upscaler.py | 15 ++- .../src/models/model_wrappers.py | 20 ++- apps/stable_diffusion/src/utils/__init__.py | 1 + .../src/utils/sd_annotation.py | 21 +-- apps/stable_diffusion/src/utils/utils.py | 120 +++++++++++++++++- .../web/ui/css/sd_dark_theme.css | 1 + apps/stable_diffusion/web/ui/img2img_ui.py | 4 +- apps/stable_diffusion/web/ui/inpaint_ui.py | 4 +- apps/stable_diffusion/web/ui/outpaint_ui.py | 4 +- apps/stable_diffusion/web/ui/txt2img_ui.py | 4 +- apps/stable_diffusion/web/ui/upscaler_ui.py | 17 +++ apps/stable_diffusion/web/ui/utils.py | 38 +++++- requirements.txt | 2 +- 17 files changed, 240 insertions(+), 98 deletions(-) diff --git a/apps/stable_diffusion/scripts/img2img.py b/apps/stable_diffusion/scripts/img2img.py index 7c93447f..e50c0b86 100644 --- a/apps/stable_diffusion/scripts/img2img.py +++ b/apps/stable_diffusion/scripts/img2img.py @@ -77,6 +77,7 @@ def img2img_inf( ): from apps.stable_diffusion.web.ui.utils import ( get_custom_model_pathfile, + get_custom_vae_or_lora_weights, Config, ) import apps.stable_diffusion.web.utils.global_obj as global_obj @@ -100,10 +101,6 @@ def img2img_inf( image = init_image.convert("RGB") # set ckpt_loc and hf_model_id. - types = ( - ".ckpt", - ".safetensors", - ) # the tuple of file types args.ckpt_loc = "" args.hf_model_id = "" if custom_model == "None": @@ -118,14 +115,9 @@ def img2img_inf( else: args.hf_model_id = custom_model - use_lora = "" - if lora_weights == "None" and not lora_hf_id: - use_lora = "" - elif not lora_hf_id: - use_lora = lora_weights - else: - use_lora = lora_hf_id - args.use_lora = use_lora + args.use_lora = get_custom_vae_or_lora_weights( + lora_weights, lora_hf_id, "lora" + ) args.save_metadata_to_json = save_metadata_to_json args.write_metadata_to_png = save_metadata_to_png @@ -159,7 +151,7 @@ def img2img_inf( height, width, device, - use_lora=use_lora, + use_lora=args.use_lora, use_stencil=use_stencil, ) if ( @@ -205,7 +197,7 @@ def img2img_inf( low_cpu_mem_usage=args.low_cpu_mem_usage, use_stencil=use_stencil, debug=args.import_debug if args.import_mlir else False, - use_lora=use_lora, + use_lora=args.use_lora, ) ) else: @@ -225,7 +217,7 @@ def img2img_inf( args.use_tuned, low_cpu_mem_usage=args.low_cpu_mem_usage, debug=args.import_debug if args.import_mlir else False, - use_lora=use_lora, + use_lora=args.use_lora, ) ) diff --git a/apps/stable_diffusion/scripts/inpaint.py b/apps/stable_diffusion/scripts/inpaint.py index 158eef1b..eb7cab4f 100644 --- a/apps/stable_diffusion/scripts/inpaint.py +++ b/apps/stable_diffusion/scripts/inpaint.py @@ -48,6 +48,7 @@ def inpaint_inf( ): from apps.stable_diffusion.web.ui.utils import ( get_custom_model_pathfile, + get_custom_vae_or_lora_weights, Config, ) import apps.stable_diffusion.web.utils.global_obj as global_obj @@ -66,10 +67,6 @@ def inpaint_inf( args.mask_path = "not none" # set ckpt_loc and hf_model_id. - types = ( - ".ckpt", - ".safetensors", - ) # the tuple of file types args.ckpt_loc = "" args.hf_model_id = "" if custom_model == "None": @@ -84,14 +81,9 @@ def inpaint_inf( else: args.hf_model_id = custom_model - use_lora = "" - if lora_weights == "None" and not lora_hf_id: - use_lora = "" - elif not lora_hf_id: - use_lora = lora_weights - else: - use_lora = lora_hf_id - args.use_lora = use_lora + args.use_lora = get_custom_vae_or_lora_weights( + lora_weights, lora_hf_id, "lora" + ) args.save_metadata_to_json = save_metadata_to_json args.write_metadata_to_png = save_metadata_to_png @@ -108,7 +100,7 @@ def inpaint_inf( height, width, device, - use_lora=use_lora, + use_lora=args.use_lora, use_stencil=None, ) if ( @@ -148,10 +140,9 @@ def inpaint_inf( width=args.width, use_base_vae=args.use_base_vae, use_tuned=args.use_tuned, - custom_vae=args.custom_vae, low_cpu_mem_usage=args.low_cpu_mem_usage, debug=args.import_debug if args.import_mlir else False, - use_lora=use_lora, + use_lora=args.use_lora, ) ) @@ -239,7 +230,6 @@ if __name__ == "__main__": width=args.width, use_base_vae=args.use_base_vae, use_tuned=args.use_tuned, - custom_vae=args.custom_vae, low_cpu_mem_usage=args.low_cpu_mem_usage, debug=args.import_debug if args.import_mlir else False, use_lora=args.use_lora, diff --git a/apps/stable_diffusion/scripts/outpaint.py b/apps/stable_diffusion/scripts/outpaint.py index d3fcf92d..2a922317 100644 --- a/apps/stable_diffusion/scripts/outpaint.py +++ b/apps/stable_diffusion/scripts/outpaint.py @@ -51,6 +51,7 @@ def outpaint_inf( ): from apps.stable_diffusion.web.ui.utils import ( get_custom_model_pathfile, + get_custom_vae_or_lora_weights, Config, ) import apps.stable_diffusion.web.utils.global_obj as global_obj @@ -68,10 +69,6 @@ def outpaint_inf( args.img_path = "not none" # set ckpt_loc and hf_model_id. - types = ( - ".ckpt", - ".safetensors", - ) # the tuple of file types args.ckpt_loc = "" args.hf_model_id = "" if custom_model == "None": @@ -86,14 +83,9 @@ def outpaint_inf( else: args.hf_model_id = custom_model - use_lora = "" - if lora_weights == "None" and not lora_hf_id: - use_lora = "" - elif not lora_hf_id: - use_lora = lora_weights - else: - use_lora = lora_hf_id - args.use_lora = use_lora + args.use_lora = get_custom_vae_or_lora_weights( + lora_weights, lora_hf_id, "lora" + ) args.save_metadata_to_json = save_metadata_to_json args.write_metadata_to_png = save_metadata_to_png @@ -110,7 +102,7 @@ def outpaint_inf( height, width, device, - use_lora=use_lora, + use_lora=args.use_lora, use_stencil=None, ) if ( @@ -151,7 +143,7 @@ def outpaint_inf( args.width, args.use_base_vae, args.use_tuned, - use_lora=use_lora, + use_lora=args.use_lora, ) ) diff --git a/apps/stable_diffusion/scripts/txt2img.py b/apps/stable_diffusion/scripts/txt2img.py index fa16beda..ae80247b 100644 --- a/apps/stable_diffusion/scripts/txt2img.py +++ b/apps/stable_diffusion/scripts/txt2img.py @@ -43,6 +43,7 @@ def txt2img_inf( ): from apps.stable_diffusion.web.ui.utils import ( get_custom_model_pathfile, + get_custom_vae_or_lora_weights, Config, ) import apps.stable_diffusion.web.utils.global_obj as global_obj @@ -59,10 +60,6 @@ def txt2img_inf( args.scheduler = scheduler # set ckpt_loc and hf_model_id. - types = ( - ".ckpt", - ".safetensors", - ) # the tuple of file types args.ckpt_loc = "" args.hf_model_id = "" if custom_model == "None": @@ -80,14 +77,9 @@ def txt2img_inf( args.save_metadata_to_json = save_metadata_to_json args.write_metadata_to_png = save_metadata_to_png - use_lora = "" - if lora_weights == "None" and not lora_hf_id: - use_lora = "" - elif not lora_hf_id: - use_lora = lora_weights - else: - use_lora = lora_hf_id - args.use_lora = use_lora + args.use_lora = get_custom_vae_or_lora_weights( + lora_weights, lora_hf_id, "lora" + ) dtype = torch.float32 if precision == "fp32" else torch.half cpu_scheduling = not scheduler.startswith("Shark") @@ -101,7 +93,7 @@ def txt2img_inf( height, width, device, - use_lora=use_lora, + use_lora=args.use_lora, use_stencil=None, ) if ( @@ -145,7 +137,7 @@ def txt2img_inf( custom_vae=args.custom_vae, low_cpu_mem_usage=args.low_cpu_mem_usage, debug=args.import_debug if args.import_mlir else False, - use_lora=use_lora, + use_lora=args.use_lora, ) ) @@ -200,7 +192,6 @@ if __name__ == "__main__": schedulers = get_schedulers(args.hf_model_id) scheduler_obj = schedulers[args.scheduler] seed = args.seed - use_lora = args.use_lora txt2img_obj = Text2ImagePipeline.from_pretrained( scheduler=scheduler_obj, import_mlir=args.import_mlir, @@ -216,7 +207,7 @@ if __name__ == "__main__": custom_vae=args.custom_vae, low_cpu_mem_usage=args.low_cpu_mem_usage, debug=args.import_debug if args.import_mlir else False, - use_lora=use_lora, + use_lora=args.use_lora, ) for current_batch in range(args.batch_count): diff --git a/apps/stable_diffusion/scripts/upscaler.py b/apps/stable_diffusion/scripts/upscaler.py index 0a4fa592..ce5d1bf1 100644 --- a/apps/stable_diffusion/scripts/upscaler.py +++ b/apps/stable_diffusion/scripts/upscaler.py @@ -41,9 +41,12 @@ def upscaler_inf( max_length: int, save_metadata_to_json: bool, save_metadata_to_png: bool, + lora_weights: str, + lora_hf_id: str, ): from apps.stable_diffusion.web.ui.utils import ( get_custom_model_pathfile, + get_custom_vae_or_lora_weights, Config, ) import apps.stable_diffusion.web.utils.global_obj as global_obj @@ -62,10 +65,6 @@ def upscaler_inf( image = init_image.convert("RGB").resize((height, width)) # set ckpt_loc and hf_model_id. - types = ( - ".ckpt", - ".safetensors", - ) # the tuple of file types args.ckpt_loc = "" args.hf_model_id = "" if custom_model == "None": @@ -83,6 +82,10 @@ def upscaler_inf( args.save_metadata_to_json = save_metadata_to_json args.write_metadata_to_png = save_metadata_to_png + args.use_lora = get_custom_vae_or_lora_weights( + lora_weights, lora_hf_id, "lora" + ) + dtype = torch.float32 if precision == "fp32" else torch.half cpu_scheduling = not scheduler.startswith("Shark") args.height = 128 @@ -97,7 +100,7 @@ def upscaler_inf( args.height, args.width, device, - use_lora=None, + use_lora=args.use_lora, use_stencil=None, ) if ( @@ -135,6 +138,7 @@ def upscaler_inf( args.use_base_vae, args.use_tuned, low_cpu_mem_usage=args.low_cpu_mem_usage, + use_lora=args.use_lora, ) ) @@ -232,6 +236,7 @@ if __name__ == "__main__": args.use_base_vae, args.use_tuned, low_cpu_mem_usage=args.low_cpu_mem_usage, + use_lora=args.use_lora, ddpm_scheduler=schedulers["DDPM"], ) diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index 96417e27..8bb5d09b 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -18,6 +18,7 @@ from apps.stable_diffusion.src.utils import ( get_path_stem, get_extended_name, get_stencil_model_id, + update_lora_weight, ) @@ -200,6 +201,7 @@ class SharkifyStableDiffusionModel: use_tuned=self.use_tuned, model_name=self.model_name["vae_encode"], extra_args=get_opt_flags("vae", precision=self.precision), + base_model_id=self.base_model_id, ) return shark_vae_encode @@ -255,6 +257,7 @@ class SharkifyStableDiffusionModel: generate_vmfb=self.generate_vmfb, save_dir=save_dir, extra_args=get_opt_flags("vae", precision=self.precision), + base_model_id=self.base_model_id, ) return shark_vae @@ -281,13 +284,14 @@ class SharkifyStableDiffusionModel: use_tuned=self.use_tuned, model_name=self.model_name["vae"], extra_args=get_opt_flags("vae", precision="fp32"), + base_model_id=self.base_model_id, ) return shark_vae def get_controlled_unet(self): class ControlledUnetModel(torch.nn.Module): def __init__( - self, model_id=self.model_id, low_cpu_mem_usage=False + self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora ): super().__init__() self.unet = UNet2DConditionModel.from_pretrained( @@ -295,6 +299,8 @@ class SharkifyStableDiffusionModel: subfolder="unet", low_cpu_mem_usage=low_cpu_mem_usage, ) + if use_lora != "": + update_lora_weight(self.unet, use_lora, "unet") self.in_channels = self.unet.in_channels self.train(False) @@ -333,6 +339,7 @@ class SharkifyStableDiffusionModel: f16_input_mask=input_mask, use_tuned=self.use_tuned, extra_args=get_opt_flags("unet", precision=self.precision), + base_model_id=self.base_model_id, ) return shark_controlled_unet @@ -386,6 +393,7 @@ class SharkifyStableDiffusionModel: f16_input_mask=input_mask, use_tuned=self.use_tuned, extra_args=get_opt_flags("unet", precision=self.precision), + base_model_id=self.base_model_id, ) return shark_cnet @@ -399,7 +407,7 @@ class SharkifyStableDiffusionModel: low_cpu_mem_usage=low_cpu_mem_usage, ) if use_lora != "": - self.unet.load_attn_procs(use_lora) + update_lora_weight(self.unet, use_lora, "unet") self.in_channels = self.unet.in_channels self.train(False) if(args.attention_slicing is not None and args.attention_slicing != "none"): @@ -444,6 +452,7 @@ class SharkifyStableDiffusionModel: generate_vmfb=self.generate_vmfb, save_dir=save_dir, extra_args=get_opt_flags("unet", precision=self.precision), + base_model_id=self.base_model_id, ) return shark_unet @@ -481,18 +490,21 @@ class SharkifyStableDiffusionModel: f16_input_mask=input_mask, use_tuned=self.use_tuned, extra_args=get_opt_flags("unet", precision=self.precision), + base_model_id=self.base_model_id, ) return shark_unet def get_clip(self): class CLIPText(torch.nn.Module): - def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False): + def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora): super().__init__() self.text_encoder = CLIPTextModel.from_pretrained( model_id, subfolder="text_encoder", low_cpu_mem_usage=low_cpu_mem_usage, ) + if use_lora != "": + update_lora_weight(self.text_encoder, use_lora, "text_encoder") def forward(self, input): return self.text_encoder(input)[0] @@ -512,6 +524,7 @@ class SharkifyStableDiffusionModel: generate_vmfb=self.generate_vmfb, save_dir=save_dir, extra_args=get_opt_flags("clip", precision="fp32"), + base_model_id=self.base_model_id, ) return shark_clip @@ -539,6 +552,7 @@ class SharkifyStableDiffusionModel: # Compiles Clip, Unet and Vae with `base_model_id` as defining their input # configiration. def compile_all(self, base_model_id, need_vae_encode, need_stencil): + self.base_model_id = base_model_id self.inputs = get_input_info( base_models[base_model_id], self.max_len, diff --git a/apps/stable_diffusion/src/utils/__init__.py b/apps/stable_diffusion/src/utils/__init__.py index 0a846c6d..5a749250 100644 --- a/apps/stable_diffusion/src/utils/__init__.py +++ b/apps/stable_diffusion/src/utils/__init__.py @@ -33,4 +33,5 @@ from apps.stable_diffusion.src.utils.utils import ( clear_all, save_output_img, get_generation_text_info, + update_lora_weight, ) diff --git a/apps/stable_diffusion/src/utils/sd_annotation.py b/apps/stable_diffusion/src/utils/sd_annotation.py index 96c2964e..262e6ce9 100644 --- a/apps/stable_diffusion/src/utils/sd_annotation.py +++ b/apps/stable_diffusion/src/utils/sd_annotation.py @@ -76,18 +76,19 @@ def load_winograd_configs(): return winograd_config_dir -def load_lower_configs(): +def load_lower_configs(base_model_id=None): from apps.stable_diffusion.src.models import get_variant_version from apps.stable_diffusion.src.utils.utils import ( fetch_and_update_base_model_id, ) - if args.ckpt_loc != "": - base_model_id = fetch_and_update_base_model_id(args.ckpt_loc) - else: - base_model_id = fetch_and_update_base_model_id(args.hf_model_id) - if base_model_id == "": - base_model_id = args.hf_model_id + if not base_model_id: + if args.ckpt_loc != "": + base_model_id = fetch_and_update_base_model_id(args.ckpt_loc) + else: + base_model_id = fetch_and_update_base_model_id(args.hf_model_id) + if base_model_id == "": + base_model_id = args.hf_model_id variant, version = get_variant_version(base_model_id) @@ -212,7 +213,7 @@ def annotate_with_lower_configs( return bytecode -def sd_model_annotation(mlir_model, model_name): +def sd_model_annotation(mlir_model, model_name, base_model_id=None): device = get_device() if args.annotation_model == "unet" and device == "vulkan": use_winograd = True @@ -220,7 +221,7 @@ def sd_model_annotation(mlir_model, model_name): winograd_model = annotate_with_winograd( mlir_model, winograd_config_dir, model_name ) - lowering_config_dir = load_lower_configs() + lowering_config_dir = load_lower_configs(base_model_id) tuned_model = annotate_with_lower_configs( winograd_model, lowering_config_dir, model_name, use_winograd ) @@ -232,7 +233,7 @@ def sd_model_annotation(mlir_model, model_name): ) else: use_winograd = False - lowering_config_dir = load_lower_configs() + lowering_config_dir = load_lower_configs(base_model_id) tuned_model = annotate_with_lower_configs( mlir_model, lowering_config_dir, model_name, use_winograd ) diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 9fcf8f50..9b56dce6 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -9,6 +9,8 @@ from pathlib import Path import numpy as np from random import randint import tempfile +import torch +from safetensors.torch import load_file from shark.shark_inference import SharkInference from shark.shark_importer import import_with_fx from shark.iree_utils.vulkan_utils import ( @@ -21,7 +23,7 @@ from apps.stable_diffusion.src.utils.resources import opt_flags from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation import sys from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( - load_pipeline_from_original_stable_diffusion_ckpt, + download_from_original_stable_diffusion_ckpt, ) @@ -95,6 +97,7 @@ def compile_through_fx( debug=False, generate_vmfb=True, extra_args=[], + base_model_id=None, ): from shark.parser import shark_args @@ -116,7 +119,9 @@ def compile_through_fx( if use_tuned: if "vae" in model_name.split("_")[0]: args.annotation_model = "vae" - mlir_module = sd_model_annotation(mlir_module, model_name) + mlir_module = sd_model_annotation( + mlir_module, model_name, base_model_id + ) shark_module = SharkInference( mlir_module, @@ -454,7 +459,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False): "Loading diffusers' pipeline from original stable diffusion checkpoint" ) num_in_channels = 9 if is_inpaint else 4 - pipe = load_pipeline_from_original_stable_diffusion_ckpt( + pipe = download_from_original_stable_diffusion_ckpt( checkpoint_path=custom_weights, extract_ema=extract_ema, from_safetensors=from_safetensors, @@ -464,6 +469,115 @@ def preprocessCKPT(custom_weights, is_inpaint=False): print("Loading complete") +def processLoRA(model, use_lora, splitting_prefix): + state_dict = "" + if ".safetensors" in use_lora: + state_dict = load_file(use_lora) + else: + state_dict = torch.load(use_lora) + alpha = 0.75 + visited = [] + + # directly update weight in model + process_unet = "te" not in splitting_prefix + for key in state_dict: + if ".alpha" in key or key in visited: + continue + + curr_layer = model + if ("text" not in key and process_unet) or ( + "text" in key and not process_unet + ): + layer_infos = ( + key.split(".")[0].split(splitting_prefix)[-1].split("_") + ) + else: + continue + + # find the target layer + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + pair_keys = [] + if "lora_down" in key: + pair_keys.append(key.replace("lora_down", "lora_up")) + pair_keys.append(key) + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora_up", "lora_down")) + + # update weight + if len(state_dict[pair_keys[0]].shape) == 4: + weight_up = ( + state_dict[pair_keys[0]] + .squeeze(3) + .squeeze(2) + .to(torch.float32) + ) + weight_down = ( + state_dict[pair_keys[1]] + .squeeze(3) + .squeeze(2) + .to(torch.float32) + ) + curr_layer.weight.data += alpha * torch.mm( + weight_up, weight_down + ).unsqueeze(2).unsqueeze(3) + else: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down) + # update visited list + for item in pair_keys: + visited.append(item) + return model + + +def update_lora_weight_for_unet(unet, use_lora): + extensions = [".bin", ".safetensors", ".pt"] + if not any([extension in use_lora for extension in extensions]): + # We assume if it is a HF ID with standalone LoRA weights. + unet.load_attn_procs(use_lora) + return unet + + main_file_name = get_path_stem(use_lora) + if ".bin" in use_lora: + main_file_name += ".bin" + elif ".safetensors" in use_lora: + main_file_name += ".safetensors" + elif ".pt" in use_lora: + main_file_name += ".pt" + else: + sys.exit("Only .bin and .safetensors format for LoRA is supported") + + try: + dir_name = os.path.dirname(use_lora) + unet.load_attn_procs(dir_name, weight_name=main_file_name) + return unet + except: + return processLoRA(unet, use_lora, "lora_unet_") + + +def update_lora_weight(model, use_lora, model_name): + if "unet" in model_name: + return update_lora_weight_for_unet(model, use_lora) + try: + return processLoRA(model, use_lora, "lora_te_") + except: + return None + + def load_vmfb(vmfb_path, model, precision): model = "vae" if "base_vae" in model or "vae_encode" in model else model model = "unet" if "stencil" in model else model diff --git a/apps/stable_diffusion/web/ui/css/sd_dark_theme.css b/apps/stable_diffusion/web/ui/css/sd_dark_theme.css index 01557ad1..be30a8ca 100644 --- a/apps/stable_diffusion/web/ui/css/sd_dark_theme.css +++ b/apps/stable_diffusion/web/ui/css/sd_dark_theme.css @@ -178,6 +178,7 @@ footer { /* Hide "remove buttons" from ui dropdowns */ #custom_model .token-remove.remove-all, +#lora_weights .token-remove.remove-all, #scheduler .token-remove.remove-all, #device .token-remove.remove-all, #stencil_model .token-remove.remove-all { diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index cd90b7cb..537a9b9d 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -77,10 +77,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web: with gr.Accordion(label="LoRA Options", open=False): with gr.Row(): lora_weights = gr.Dropdown( - label=f"Standlone LoRA weights (Path: {get_custom_model_path()})", + label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})", elem_id="lora_weights", value="None", - choices=["None"] + get_custom_model_files(), + choices=["None"] + get_custom_model_files("lora"), ) lora_hf_id = gr.Textbox( elem_id="lora_hf_id", diff --git a/apps/stable_diffusion/web/ui/inpaint_ui.py b/apps/stable_diffusion/web/ui/inpaint_ui.py index 50de1a93..9434ee7f 100644 --- a/apps/stable_diffusion/web/ui/inpaint_ui.py +++ b/apps/stable_diffusion/web/ui/inpaint_ui.py @@ -72,10 +72,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web: with gr.Accordion(label="LoRA Options", open=False): with gr.Row(): lora_weights = gr.Dropdown( - label=f"Standlone LoRA weights (Path: {get_custom_model_path()})", + label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})", elem_id="lora_weights", value="None", - choices=["None"] + get_custom_model_files(), + choices=["None"] + get_custom_model_files("lora"), ) lora_hf_id = gr.Textbox( elem_id="lora_hf_id", diff --git a/apps/stable_diffusion/web/ui/outpaint_ui.py b/apps/stable_diffusion/web/ui/outpaint_ui.py index cdc0a944..e15fb669 100644 --- a/apps/stable_diffusion/web/ui/outpaint_ui.py +++ b/apps/stable_diffusion/web/ui/outpaint_ui.py @@ -69,10 +69,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web: with gr.Accordion(label="LoRA Options", open=False): with gr.Row(): lora_weights = gr.Dropdown( - label=f"Standlone LoRA weights (Path: {get_custom_model_path()})", + label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})", elem_id="lora_weights", value="None", - choices=["None"] + get_custom_model_files(), + choices=["None"] + get_custom_model_files("lora"), ) lora_hf_id = gr.Textbox( elem_id="lora_hf_id", diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index 3aab6323..9a654c5d 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -73,10 +73,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web: with gr.Accordion(label="LoRA Options", open=False): with gr.Row(): lora_weights = gr.Dropdown( - label=f"Standlone LoRA weights (Path: {get_custom_model_path()})", + label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})", elem_id="lora_weights", value="None", - choices=["None"] + get_custom_model_files(), + choices=["None"] + get_custom_model_files("lora"), ) lora_hf_id = gr.Textbox( elem_id="lora_hf_id", diff --git a/apps/stable_diffusion/web/ui/upscaler_ui.py b/apps/stable_diffusion/web/ui/upscaler_ui.py index dc775f07..0841be6b 100644 --- a/apps/stable_diffusion/web/ui/upscaler_ui.py +++ b/apps/stable_diffusion/web/ui/upscaler_ui.py @@ -65,6 +65,21 @@ with gr.Blocks(title="Upscaler") as upscaler_web: label="Input Image", type="pil" ).style(height=300) + with gr.Accordion(label="LoRA Options", open=False): + with gr.Row(): + lora_weights = gr.Dropdown( + label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})", + elem_id="lora_weights", + value="None", + choices=["None"] + get_custom_model_files("lora"), + ) + lora_hf_id = gr.Textbox( + elem_id="lora_hf_id", + placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4", + value="", + label="HuggingFace Model ID", + lines=3, + ) with gr.Accordion(label="Advanced Options", open=False): with gr.Row(): scheduler = gr.Dropdown( @@ -226,6 +241,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web: max_length, save_metadata_to_json, save_metadata_to_png, + lora_weights, + lora_hf_id, ], outputs=[upscaler_gallery, std_output], show_progress=args.progress_bar, diff --git a/apps/stable_diffusion/web/ui/utils.py b/apps/stable_diffusion/web/ui/utils.py index 51cdd1f1..4b4c1ec0 100644 --- a/apps/stable_diffusion/web/ui/utils.py +++ b/apps/stable_diffusion/web/ui/utils.py @@ -74,25 +74,49 @@ def resource_path(relative_path): return os.path.join(base_path, relative_path) -def get_custom_model_path(): - return Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd(), "models") +def get_custom_model_path(model="models"): + match model: + case "models": + return Path(Path.cwd(), "models") + case "vae": + return Path(Path.cwd(), "models/vae") + case "lora": + return Path(Path.cwd(), "models/lora") + case _: + return "" -def get_custom_model_pathfile(custom_model_name): - return os.path.join(get_custom_model_path(), custom_model_name) +def get_custom_model_pathfile(custom_model_name, model="models"): + return os.path.join(get_custom_model_path(model), custom_model_name) -def get_custom_model_files(): +def get_custom_model_files(model="models"): ckpt_files = [] - for extn in custom_model_filetypes: + file_types = custom_model_filetypes + if model == "lora": + file_types = custom_model_filetypes + ("*.pt", "*.bin") + for extn in file_types: files = [ os.path.basename(x) - for x in glob.glob(os.path.join(get_custom_model_path(), extn)) + for x in glob.glob( + os.path.join(get_custom_model_path(model), extn) + ) ] ckpt_files.extend(files) return sorted(ckpt_files, key=str.casefold) +def get_custom_vae_or_lora_weights(weights, hf_id, model): + use_weight = "" + if weights == "None" and not hf_id: + use_weight = "" + elif not hf_id: + use_weight = get_custom_model_pathfile(weights, model) + else: + use_weight = hf_id + return use_weight + + def cancel_sd(): # Try catch it, as gc can delete global_obj.sd_obj while switching model try: diff --git a/requirements.txt b/requirements.txt index ae4844ae..6065b04a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ parameterized # Add transformers, diffusers and scipy since it most commonly used transformers -diffusers +diffusers @ git+https://github.com/huggingface/diffusers@main scipy ftfy gradio