[SD] Update LoRA inference to handle various checkpoints (#1215)
From the nod-ai/SHARK-Studio repository (https://github.com/nod-ai/SHARK-Studio.git).
@@ -77,6 +77,7 @@ def img2img_inf(
 ):
     from apps.stable_diffusion.web.ui.utils import (
         get_custom_model_pathfile,
+        get_custom_vae_or_lora_weights,
         Config,
     )
     import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -100,10 +101,6 @@ def img2img_inf(
     image = init_image.convert("RGB")

     # set ckpt_loc and hf_model_id.
-    types = (
-        ".ckpt",
-        ".safetensors",
-    )  # the tuple of file types
     args.ckpt_loc = ""
     args.hf_model_id = ""
     if custom_model == "None":
@@ -118,14 +115,9 @@ def img2img_inf(
     else:
         args.hf_model_id = custom_model

-    use_lora = ""
-    if lora_weights == "None" and not lora_hf_id:
-        use_lora = ""
-    elif not lora_hf_id:
-        use_lora = lora_weights
-    else:
-        use_lora = lora_hf_id
-    args.use_lora = use_lora
+    args.use_lora = get_custom_vae_or_lora_weights(
+        lora_weights, lora_hf_id, "lora"
+    )

     args.save_metadata_to_json = save_metadata_to_json
     args.write_metadata_to_png = save_metadata_to_png
@@ -159,7 +151,7 @@ def img2img_inf(
         height,
         width,
         device,
-        use_lora=use_lora,
+        use_lora=args.use_lora,
         use_stencil=use_stencil,
     )
     if (
@@ -205,7 +197,7 @@ def img2img_inf(
                 low_cpu_mem_usage=args.low_cpu_mem_usage,
                 use_stencil=use_stencil,
                 debug=args.import_debug if args.import_mlir else False,
-                use_lora=use_lora,
+                use_lora=args.use_lora,
             )
         )
     else:
@@ -225,7 +217,7 @@ def img2img_inf(
                 args.use_tuned,
                 low_cpu_mem_usage=args.low_cpu_mem_usage,
                 debug=args.import_debug if args.import_mlir else False,
-                use_lora=use_lora,
+                use_lora=args.use_lora,
             )
         )

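For reference: each `*_inf` entry point in this commit (img2img above, and inpaint, outpaint, txt2img, and upscaler below) replaces its inlined LoRA branching with a call to `get_custom_vae_or_lora_weights`, defined near the end of the diff. A minimal standalone sketch of the resolution order that helper implements (`resolve_lora` is an illustrative name, not from the commit):

def resolve_lora(weights: str, hf_id: str) -> str:
    # Same precedence as get_custom_vae_or_lora_weights with model="lora".
    if weights == "None" and not hf_id:
        return ""       # LoRA disabled
    if not hf_id:
        return weights  # dropdown file (the real helper expands this under models/lora)
    return hf_id        # an explicit HuggingFace model ID always wins

assert resolve_lora("None", "") == ""
assert resolve_lora("mylora.safetensors", "") == "mylora.safetensors"
assert resolve_lora("None", "sayakpaul/sd-model-finetuned-lora-t4") != ""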
@@ -48,6 +48,7 @@ def inpaint_inf(
 ):
     from apps.stable_diffusion.web.ui.utils import (
         get_custom_model_pathfile,
+        get_custom_vae_or_lora_weights,
         Config,
     )
     import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -66,10 +67,6 @@ def inpaint_inf(
     args.mask_path = "not none"

     # set ckpt_loc and hf_model_id.
-    types = (
-        ".ckpt",
-        ".safetensors",
-    )  # the tuple of file types
     args.ckpt_loc = ""
     args.hf_model_id = ""
     if custom_model == "None":
@@ -84,14 +81,9 @@ def inpaint_inf(
     else:
         args.hf_model_id = custom_model

-    use_lora = ""
-    if lora_weights == "None" and not lora_hf_id:
-        use_lora = ""
-    elif not lora_hf_id:
-        use_lora = lora_weights
-    else:
-        use_lora = lora_hf_id
-    args.use_lora = use_lora
+    args.use_lora = get_custom_vae_or_lora_weights(
+        lora_weights, lora_hf_id, "lora"
+    )

     args.save_metadata_to_json = save_metadata_to_json
     args.write_metadata_to_png = save_metadata_to_png
@@ -108,7 +100,7 @@ def inpaint_inf(
         height,
         width,
         device,
-        use_lora=use_lora,
+        use_lora=args.use_lora,
         use_stencil=None,
     )
     if (
@@ -148,10 +140,9 @@ def inpaint_inf(
                 width=args.width,
                 use_base_vae=args.use_base_vae,
                 use_tuned=args.use_tuned,
                 custom_vae=args.custom_vae,
                 low_cpu_mem_usage=args.low_cpu_mem_usage,
                 debug=args.import_debug if args.import_mlir else False,
-                use_lora=use_lora,
+                use_lora=args.use_lora,
             )
         )
@@ -239,7 +230,6 @@ if __name__ == "__main__":
         width=args.width,
         use_base_vae=args.use_base_vae,
         use_tuned=args.use_tuned,
         custom_vae=args.custom_vae,
         low_cpu_mem_usage=args.low_cpu_mem_usage,
         debug=args.import_debug if args.import_mlir else False,
-        use_lora=args.use_lora,

@@ -51,6 +51,7 @@ def outpaint_inf(
 ):
     from apps.stable_diffusion.web.ui.utils import (
         get_custom_model_pathfile,
+        get_custom_vae_or_lora_weights,
         Config,
     )
     import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -68,10 +69,6 @@ def outpaint_inf(
     args.img_path = "not none"

     # set ckpt_loc and hf_model_id.
-    types = (
-        ".ckpt",
-        ".safetensors",
-    )  # the tuple of file types
     args.ckpt_loc = ""
     args.hf_model_id = ""
     if custom_model == "None":
@@ -86,14 +83,9 @@ def outpaint_inf(
     else:
         args.hf_model_id = custom_model

-    use_lora = ""
-    if lora_weights == "None" and not lora_hf_id:
-        use_lora = ""
-    elif not lora_hf_id:
-        use_lora = lora_weights
-    else:
-        use_lora = lora_hf_id
-    args.use_lora = use_lora
+    args.use_lora = get_custom_vae_or_lora_weights(
+        lora_weights, lora_hf_id, "lora"
+    )

     args.save_metadata_to_json = save_metadata_to_json
     args.write_metadata_to_png = save_metadata_to_png
@@ -110,7 +102,7 @@ def outpaint_inf(
         height,
         width,
         device,
-        use_lora=use_lora,
+        use_lora=args.use_lora,
         use_stencil=None,
     )
     if (
@@ -151,7 +143,7 @@ def outpaint_inf(
                 args.width,
                 args.use_base_vae,
                 args.use_tuned,
-                use_lora=use_lora,
+                use_lora=args.use_lora,
             )
         )

@@ -43,6 +43,7 @@ def txt2img_inf(
 ):
     from apps.stable_diffusion.web.ui.utils import (
         get_custom_model_pathfile,
+        get_custom_vae_or_lora_weights,
         Config,
     )
     import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -59,10 +60,6 @@ def txt2img_inf(
     args.scheduler = scheduler

     # set ckpt_loc and hf_model_id.
-    types = (
-        ".ckpt",
-        ".safetensors",
-    )  # the tuple of file types
     args.ckpt_loc = ""
     args.hf_model_id = ""
     if custom_model == "None":
@@ -80,14 +77,9 @@ def txt2img_inf(
     args.save_metadata_to_json = save_metadata_to_json
     args.write_metadata_to_png = save_metadata_to_png

-    use_lora = ""
-    if lora_weights == "None" and not lora_hf_id:
-        use_lora = ""
-    elif not lora_hf_id:
-        use_lora = lora_weights
-    else:
-        use_lora = lora_hf_id
-    args.use_lora = use_lora
+    args.use_lora = get_custom_vae_or_lora_weights(
+        lora_weights, lora_hf_id, "lora"
+    )

     dtype = torch.float32 if precision == "fp32" else torch.half
     cpu_scheduling = not scheduler.startswith("Shark")
@@ -101,7 +93,7 @@ def txt2img_inf(
         height,
         width,
         device,
-        use_lora=use_lora,
+        use_lora=args.use_lora,
         use_stencil=None,
     )
     if (
@@ -145,7 +137,7 @@ def txt2img_inf(
                 custom_vae=args.custom_vae,
                 low_cpu_mem_usage=args.low_cpu_mem_usage,
                 debug=args.import_debug if args.import_mlir else False,
-                use_lora=use_lora,
+                use_lora=args.use_lora,
             )
         )
@@ -200,7 +192,6 @@ if __name__ == "__main__":
     schedulers = get_schedulers(args.hf_model_id)
     scheduler_obj = schedulers[args.scheduler]
     seed = args.seed
-    use_lora = args.use_lora
     txt2img_obj = Text2ImagePipeline.from_pretrained(
         scheduler=scheduler_obj,
         import_mlir=args.import_mlir,
@@ -216,7 +207,7 @@ if __name__ == "__main__":
         custom_vae=args.custom_vae,
         low_cpu_mem_usage=args.low_cpu_mem_usage,
         debug=args.import_debug if args.import_mlir else False,
-        use_lora=use_lora,
+        use_lora=args.use_lora,
     )

     for current_batch in range(args.batch_count):

@@ -41,9 +41,12 @@ def upscaler_inf(
     max_length: int,
     save_metadata_to_json: bool,
     save_metadata_to_png: bool,
+    lora_weights: str,
+    lora_hf_id: str,
 ):
     from apps.stable_diffusion.web.ui.utils import (
         get_custom_model_pathfile,
+        get_custom_vae_or_lora_weights,
         Config,
     )
     import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -62,10 +65,6 @@ def upscaler_inf(
     image = init_image.convert("RGB").resize((height, width))

     # set ckpt_loc and hf_model_id.
-    types = (
-        ".ckpt",
-        ".safetensors",
-    )  # the tuple of file types
     args.ckpt_loc = ""
     args.hf_model_id = ""
     if custom_model == "None":
@@ -83,6 +82,10 @@ def upscaler_inf(
     args.save_metadata_to_json = save_metadata_to_json
     args.write_metadata_to_png = save_metadata_to_png

+    args.use_lora = get_custom_vae_or_lora_weights(
+        lora_weights, lora_hf_id, "lora"
+    )
+
     dtype = torch.float32 if precision == "fp32" else torch.half
     cpu_scheduling = not scheduler.startswith("Shark")
     args.height = 128
@@ -97,7 +100,7 @@ def upscaler_inf(
         args.height,
         args.width,
         device,
-        use_lora=None,
+        use_lora=args.use_lora,
         use_stencil=None,
     )
     if (
@@ -135,6 +138,7 @@ def upscaler_inf(
                 args.use_base_vae,
                 args.use_tuned,
                 low_cpu_mem_usage=args.low_cpu_mem_usage,
+                use_lora=args.use_lora,
             )
         )
@@ -232,6 +236,7 @@ if __name__ == "__main__":
         args.use_base_vae,
         args.use_tuned,
         low_cpu_mem_usage=args.low_cpu_mem_usage,
+        use_lora=args.use_lora,
         ddpm_scheduler=schedulers["DDPM"],
     )

@@ -18,6 +18,7 @@ from apps.stable_diffusion.src.utils import (
     get_path_stem,
     get_extended_name,
     get_stencil_model_id,
+    update_lora_weight,
 )
@@ -200,6 +201,7 @@ class SharkifyStableDiffusionModel:
             use_tuned=self.use_tuned,
             model_name=self.model_name["vae_encode"],
             extra_args=get_opt_flags("vae", precision=self.precision),
+            base_model_id=self.base_model_id,
         )
         return shark_vae_encode
@@ -255,6 +257,7 @@ class SharkifyStableDiffusionModel:
             generate_vmfb=self.generate_vmfb,
             save_dir=save_dir,
             extra_args=get_opt_flags("vae", precision=self.precision),
+            base_model_id=self.base_model_id,
         )
         return shark_vae
@@ -281,13 +284,14 @@ class SharkifyStableDiffusionModel:
             use_tuned=self.use_tuned,
             model_name=self.model_name["vae"],
             extra_args=get_opt_flags("vae", precision="fp32"),
+            base_model_id=self.base_model_id,
         )
         return shark_vae

     def get_controlled_unet(self):
         class ControlledUnetModel(torch.nn.Module):
             def __init__(
-                self, model_id=self.model_id, low_cpu_mem_usage=False
+                self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
             ):
                 super().__init__()
                 self.unet = UNet2DConditionModel.from_pretrained(
@@ -295,6 +299,8 @@ class SharkifyStableDiffusionModel:
                     subfolder="unet",
                     low_cpu_mem_usage=low_cpu_mem_usage,
                 )
+                if use_lora != "":
+                    update_lora_weight(self.unet, use_lora, "unet")
                 self.in_channels = self.unet.in_channels
                 self.train(False)
@@ -333,6 +339,7 @@ class SharkifyStableDiffusionModel:
             f16_input_mask=input_mask,
             use_tuned=self.use_tuned,
             extra_args=get_opt_flags("unet", precision=self.precision),
+            base_model_id=self.base_model_id,
         )
         return shark_controlled_unet
@@ -386,6 +393,7 @@ class SharkifyStableDiffusionModel:
             f16_input_mask=input_mask,
             use_tuned=self.use_tuned,
             extra_args=get_opt_flags("unet", precision=self.precision),
+            base_model_id=self.base_model_id,
         )
         return shark_cnet
@@ -399,7 +407,7 @@ class SharkifyStableDiffusionModel:
                     low_cpu_mem_usage=low_cpu_mem_usage,
                 )
                 if use_lora != "":
-                    self.unet.load_attn_procs(use_lora)
+                    update_lora_weight(self.unet, use_lora, "unet")
                 self.in_channels = self.unet.in_channels
                 self.train(False)
                 if(args.attention_slicing is not None and args.attention_slicing != "none"):
@@ -444,6 +452,7 @@ class SharkifyStableDiffusionModel:
             generate_vmfb=self.generate_vmfb,
             save_dir=save_dir,
             extra_args=get_opt_flags("unet", precision=self.precision),
+            base_model_id=self.base_model_id,
         )
         return shark_unet
@@ -481,18 +490,21 @@ class SharkifyStableDiffusionModel:
             f16_input_mask=input_mask,
             use_tuned=self.use_tuned,
             extra_args=get_opt_flags("unet", precision=self.precision),
+            base_model_id=self.base_model_id,
         )
         return shark_unet

     def get_clip(self):
         class CLIPText(torch.nn.Module):
-            def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
+            def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
                 super().__init__()
                 self.text_encoder = CLIPTextModel.from_pretrained(
                     model_id,
                     subfolder="text_encoder",
                     low_cpu_mem_usage=low_cpu_mem_usage,
                 )
+                if use_lora != "":
+                    update_lora_weight(self.text_encoder, use_lora, "text_encoder")

             def forward(self, input):
                 return self.text_encoder(input)[0]
@@ -512,6 +524,7 @@ class SharkifyStableDiffusionModel:
             generate_vmfb=self.generate_vmfb,
             save_dir=save_dir,
             extra_args=get_opt_flags("clip", precision="fp32"),
+            base_model_id=self.base_model_id,
         )
         return shark_clip
@@ -539,6 +552,7 @@ class SharkifyStableDiffusionModel:
     # Compiles Clip, Unet and Vae with `base_model_id` as defining their input
    # configiration.
     def compile_all(self, base_model_id, need_vae_encode, need_stencil):
+        self.base_model_id = base_model_id
         self.inputs = get_input_info(
             base_models[base_model_id],
             self.max_len,

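The inner wrapper classes above (ControlledUnetModel, CLIPText, and the plain UNet wrapper) now take a `use_lora` default bound from the enclosing SharkifyStableDiffusionModel and patch the freshly loaded module in place before export. A minimal sketch of that construction-time hook, assuming `update_lora_weight` is importable as the import hunk above adds it (the wrapper class itself is illustrative, not from the commit):

import torch

from apps.stable_diffusion.src.utils import update_lora_weight


class TextEncoderWrapper(torch.nn.Module):
    # Load the pretrained module first, then merge LoRA weights in place
    # when a source is configured -- the same guard the diff adds.
    def __init__(self, encoder: torch.nn.Module, use_lora: str = ""):
        super().__init__()
        self.text_encoder = encoder
        if use_lora != "":  # "" means LoRA disabled
            update_lora_weight(self.text_encoder, use_lora, "text_encoder")

    def forward(self, input):
        return self.text_encoder(input)[0]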
@@ -33,4 +33,5 @@ from apps.stable_diffusion.src.utils.utils import (
     clear_all,
     save_output_img,
     get_generation_text_info,
+    update_lora_weight,
 )

@@ -76,18 +76,19 @@ def load_winograd_configs():
     return winograd_config_dir


-def load_lower_configs():
+def load_lower_configs(base_model_id=None):
     from apps.stable_diffusion.src.models import get_variant_version
     from apps.stable_diffusion.src.utils.utils import (
         fetch_and_update_base_model_id,
     )

-    if args.ckpt_loc != "":
-        base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
-    else:
-        base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
-    if base_model_id == "":
-        base_model_id = args.hf_model_id
+    if not base_model_id:
+        if args.ckpt_loc != "":
+            base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
+        else:
+            base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
+        if base_model_id == "":
+            base_model_id = args.hf_model_id

     variant, version = get_variant_version(base_model_id)

@@ -212,7 +213,7 @@ def annotate_with_lower_configs(
     return bytecode


-def sd_model_annotation(mlir_model, model_name):
+def sd_model_annotation(mlir_model, model_name, base_model_id=None):
     device = get_device()
     if args.annotation_model == "unet" and device == "vulkan":
         use_winograd = True
@@ -220,7 +221,7 @@ def sd_model_annotation(mlir_model, model_name):
         winograd_model = annotate_with_winograd(
             mlir_model, winograd_config_dir, model_name
         )
-        lowering_config_dir = load_lower_configs()
+        lowering_config_dir = load_lower_configs(base_model_id)
         tuned_model = annotate_with_lower_configs(
             winograd_model, lowering_config_dir, model_name, use_winograd
         )
@@ -232,7 +233,7 @@ def sd_model_annotation(mlir_model, model_name):
         )
     else:
         use_winograd = False
-        lowering_config_dir = load_lower_configs()
+        lowering_config_dir = load_lower_configs(base_model_id)
         tuned_model = annotate_with_lower_configs(
             mlir_model, lowering_config_dir, model_name, use_winograd
         )

@@ -9,6 +9,8 @@ from pathlib import Path
 import numpy as np
 from random import randint
 import tempfile
+import torch
+from safetensors.torch import load_file
 from shark.shark_inference import SharkInference
 from shark.shark_importer import import_with_fx
 from shark.iree_utils.vulkan_utils import (
@@ -21,7 +23,7 @@ from apps.stable_diffusion.src.utils.resources import opt_flags
 from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
 import sys
 from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
-    load_pipeline_from_original_stable_diffusion_ckpt,
+    download_from_original_stable_diffusion_ckpt,
 )
@@ -95,6 +97,7 @@ def compile_through_fx(
     debug=False,
     generate_vmfb=True,
     extra_args=[],
+    base_model_id=None,
 ):
     from shark.parser import shark_args
@@ -116,7 +119,9 @@ def compile_through_fx(
     if use_tuned:
         if "vae" in model_name.split("_")[0]:
             args.annotation_model = "vae"
-        mlir_module = sd_model_annotation(mlir_module, model_name)
+        mlir_module = sd_model_annotation(
+            mlir_module, model_name, base_model_id
+        )

     shark_module = SharkInference(
         mlir_module,
@@ -454,7 +459,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
         "Loading diffusers' pipeline from original stable diffusion checkpoint"
     )
     num_in_channels = 9 if is_inpaint else 4
-    pipe = load_pipeline_from_original_stable_diffusion_ckpt(
+    pipe = download_from_original_stable_diffusion_ckpt(
         checkpoint_path=custom_weights,
         extract_ema=extract_ema,
         from_safetensors=from_safetensors,

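The switch from `load_pipeline_from_original_stable_diffusion_ckpt` to `download_from_original_stable_diffusion_ckpt` tracks a rename of that helper on diffusers' main branch, which is also why the requirements hunk at the bottom of this commit pins `diffusers @ git+...@main`. A hedged usage sketch (the checkpoint file name is illustrative):

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# Same call shape as preprocessCKPT above; only the path is made up.
pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path="v1-5-pruned-emaonly.safetensors",
    extract_ema=True,
    from_safetensors=True,
    num_in_channels=4,  # preprocessCKPT passes 9 for inpainting checkpoints
)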
@@ -464,6 +469,115 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
     print("Loading complete")


+def processLoRA(model, use_lora, splitting_prefix):
+    state_dict = ""
+    if ".safetensors" in use_lora:
+        state_dict = load_file(use_lora)
+    else:
+        state_dict = torch.load(use_lora)
+    alpha = 0.75
+    visited = []
+
+    # directly update weight in model
+    process_unet = "te" not in splitting_prefix
+    for key in state_dict:
+        if ".alpha" in key or key in visited:
+            continue
+
+        curr_layer = model
+        if ("text" not in key and process_unet) or (
+            "text" in key and not process_unet
+        ):
+            layer_infos = (
+                key.split(".")[0].split(splitting_prefix)[-1].split("_")
+            )
+        else:
+            continue
+
+        # find the target layer
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(temp_name) > 0:
+                    temp_name += "_" + layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        pair_keys = []
+        if "lora_down" in key:
+            pair_keys.append(key.replace("lora_down", "lora_up"))
+            pair_keys.append(key)
+        else:
+            pair_keys.append(key)
+            pair_keys.append(key.replace("lora_up", "lora_down"))
+
+        # update weight
+        if len(state_dict[pair_keys[0]].shape) == 4:
+            weight_up = (
+                state_dict[pair_keys[0]]
+                .squeeze(3)
+                .squeeze(2)
+                .to(torch.float32)
+            )
+            weight_down = (
+                state_dict[pair_keys[1]]
+                .squeeze(3)
+                .squeeze(2)
+                .to(torch.float32)
+            )
+            curr_layer.weight.data += alpha * torch.mm(
+                weight_up, weight_down
+            ).unsqueeze(2).unsqueeze(3)
+        else:
+            weight_up = state_dict[pair_keys[0]].to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].to(torch.float32)
+            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
+        # update visited list
+        for item in pair_keys:
+            visited.append(item)
+    return model
+
+
+def update_lora_weight_for_unet(unet, use_lora):
+    extensions = [".bin", ".safetensors", ".pt"]
+    if not any([extension in use_lora for extension in extensions]):
+        # We assume if it is a HF ID with standalone LoRA weights.
+        unet.load_attn_procs(use_lora)
+        return unet
+
+    main_file_name = get_path_stem(use_lora)
+    if ".bin" in use_lora:
+        main_file_name += ".bin"
+    elif ".safetensors" in use_lora:
+        main_file_name += ".safetensors"
+    elif ".pt" in use_lora:
+        main_file_name += ".pt"
+    else:
+        sys.exit("Only .bin and .safetensors format for LoRA is supported")
+
+    try:
+        dir_name = os.path.dirname(use_lora)
+        unet.load_attn_procs(dir_name, weight_name=main_file_name)
+        return unet
+    except:
+        return processLoRA(unet, use_lora, "lora_unet_")
+
+
+def update_lora_weight(model, use_lora, model_name):
+    if "unet" in model_name:
+        return update_lora_weight_for_unet(model, use_lora)
+    try:
+        return processLoRA(model, use_lora, "lora_te_")
+    except:
+        return None
+
+
 def load_vmfb(vmfb_path, model, precision):
     model = "vae" if "base_vae" in model or "vae_encode" in model else model
     model = "unet" if "stencil" in model else model

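At its core, processLoRA folds every paired lora_up/lora_down matrix into the frozen weight as W += alpha * (up @ down), with a fixed alpha of 0.75; 4-D convolution weights are squeezed to 2-D first and the product unsqueezed back. A self-contained toy demonstration of that update (shapes and seed are arbitrary):

import torch

torch.manual_seed(0)
out_features, in_features, rank = 8, 8, 2
layer = torch.nn.Linear(in_features, out_features, bias=False)

lora_up = torch.randn(out_features, rank)   # stand-in for "...lora_up.weight"
lora_down = torch.randn(rank, in_features)  # stand-in for "...lora_down.weight"
alpha = 0.75                                # fixed scale, as in the commit

before = layer.weight.detach().clone()
# The 2-D branch of processLoRA: merge the low-rank product in place.
layer.weight.data += alpha * torch.mm(lora_up, lora_down)
assert torch.allclose(layer.weight.data - before, alpha * lora_up @ lora_down)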
@@ -178,6 +178,7 @@ footer {

 /* Hide "remove buttons" from ui dropdowns */
 #custom_model .token-remove.remove-all,
+#lora_weights .token-remove.remove-all,
 #scheduler .token-remove.remove-all,
 #device .token-remove.remove-all,
 #stencil_model .token-remove.remove-all {

@@ -77,10 +77,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
                 with gr.Accordion(label="LoRA Options", open=False):
                     with gr.Row():
                         lora_weights = gr.Dropdown(
-                            label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
+                            label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
                             elem_id="lora_weights",
                             value="None",
-                            choices=["None"] + get_custom_model_files(),
+                            choices=["None"] + get_custom_model_files("lora"),
                         )
                         lora_hf_id = gr.Textbox(
                             elem_id="lora_hf_id",

@@ -72,10 +72,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
                 with gr.Accordion(label="LoRA Options", open=False):
                     with gr.Row():
                         lora_weights = gr.Dropdown(
-                            label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
+                            label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
                             elem_id="lora_weights",
                             value="None",
-                            choices=["None"] + get_custom_model_files(),
+                            choices=["None"] + get_custom_model_files("lora"),
                         )
                         lora_hf_id = gr.Textbox(
                             elem_id="lora_hf_id",

@@ -69,10 +69,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
                 with gr.Accordion(label="LoRA Options", open=False):
                     with gr.Row():
                         lora_weights = gr.Dropdown(
-                            label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
+                            label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
                             elem_id="lora_weights",
                             value="None",
-                            choices=["None"] + get_custom_model_files(),
+                            choices=["None"] + get_custom_model_files("lora"),
                         )
                         lora_hf_id = gr.Textbox(
                             elem_id="lora_hf_id",

@@ -73,10 +73,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
                 with gr.Accordion(label="LoRA Options", open=False):
                     with gr.Row():
                         lora_weights = gr.Dropdown(
-                            label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
+                            label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
                             elem_id="lora_weights",
                             value="None",
-                            choices=["None"] + get_custom_model_files(),
+                            choices=["None"] + get_custom_model_files("lora"),
                         )
                         lora_hf_id = gr.Textbox(
                             elem_id="lora_hf_id",

@@ -65,6 +65,21 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
                         label="Input Image", type="pil"
                     ).style(height=300)

+                with gr.Accordion(label="LoRA Options", open=False):
+                    with gr.Row():
+                        lora_weights = gr.Dropdown(
+                            label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
+                            elem_id="lora_weights",
+                            value="None",
+                            choices=["None"] + get_custom_model_files("lora"),
+                        )
+                        lora_hf_id = gr.Textbox(
+                            elem_id="lora_hf_id",
+                            placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
+                            value="",
+                            label="HuggingFace Model ID",
+                            lines=3,
+                        )
                 with gr.Accordion(label="Advanced Options", open=False):
                     with gr.Row():
                         scheduler = gr.Dropdown(
@@ -226,6 +241,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
                 max_length,
                 save_metadata_to_json,
                 save_metadata_to_png,
+                lora_weights,
+                lora_hf_id,
             ],
             outputs=[upscaler_gallery, std_output],
             show_progress=args.progress_bar,

@@ -74,25 +74,49 @@ def resource_path(relative_path):
     return os.path.join(base_path, relative_path)


-def get_custom_model_path():
-    return Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd(), "models")
+def get_custom_model_path(model="models"):
+    match model:
+        case "models":
+            return Path(Path.cwd(), "models")
+        case "vae":
+            return Path(Path.cwd(), "models/vae")
+        case "lora":
+            return Path(Path.cwd(), "models/lora")
+        case _:
+            return ""


-def get_custom_model_pathfile(custom_model_name):
-    return os.path.join(get_custom_model_path(), custom_model_name)
+def get_custom_model_pathfile(custom_model_name, model="models"):
+    return os.path.join(get_custom_model_path(model), custom_model_name)


-def get_custom_model_files():
+def get_custom_model_files(model="models"):
     ckpt_files = []
-    for extn in custom_model_filetypes:
+    file_types = custom_model_filetypes
+    if model == "lora":
+        file_types = custom_model_filetypes + ("*.pt", "*.bin")
+    for extn in file_types:
         files = [
             os.path.basename(x)
-            for x in glob.glob(os.path.join(get_custom_model_path(), extn))
+            for x in glob.glob(
+                os.path.join(get_custom_model_path(model), extn)
+            )
         ]
         ckpt_files.extend(files)
     return sorted(ckpt_files, key=str.casefold)


+def get_custom_vae_or_lora_weights(weights, hf_id, model):
+    use_weight = ""
+    if weights == "None" and not hf_id:
+        use_weight = ""
+    elif not hf_id:
+        use_weight = get_custom_model_pathfile(weights, model)
+    else:
+        use_weight = hf_id
+    return use_weight
+
+
 def cancel_sd():
     # Try catch it, as gc can delete global_obj.sd_obj while switching model
     try:

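Taken together, the reworked helpers key everything off a `model` selector: get_custom_model_path maps it to a subdirectory, get_custom_model_files widens the glob to *.pt and *.bin for LoRA, and get_custom_vae_or_lora_weights applies the dropdown-versus-HF-ID precedence. A usage sketch (assumes the helpers are importable from the web UI utils module; file names are illustrative):

from apps.stable_diffusion.web.ui.utils import (
    get_custom_model_files,
    get_custom_model_path,
    get_custom_vae_or_lora_weights,
)

print(get_custom_model_path("lora"))    # <cwd>/models/lora
files = get_custom_model_files("lora")  # *.ckpt/*.safetensors plus *.pt/*.bin
use_lora = get_custom_vae_or_lora_weights(
    files[0] if files else "None",  # dropdown selection
    "",                             # no HuggingFace ID, so the local file wins
    "lora",
)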
@@ -16,7 +16,7 @@ parameterized

 # Add transformers, diffusers and scipy since it most commonly used
 transformers
-diffusers
+diffusers @ git+https://github.com/huggingface/diffusers@main
 scipy
 ftfy
 gradio