SD/UI: Merge Lora Selection Boxes, Add LoRA Strength (#2052)

* Merges LoRA selection in the UI into a single selection, rather than
one for LoRAs under ./models and another for Hugging Face Id
* Add LoRA strength to UI and pipeline parameters.
* Add a `--lora_strength` command line argument.
* Bake LoRA strength into .vmfb naming when a LoRA is specified.
* Use LoRA embedded alpha values and (up tensor dimension *
LoRA strength) for final alpha when applying LoRA weights rather
than a hardcoded value of 0.75
* Adds additional cases to the LoRA weight application that are
present for weight application in the Kohya scripts.
* Include lora strength when reading and writing png metadata.
* Allow lora_strength to be set above 1.0 in the UI, so similar effects
to the prior (overdriven alpha) implementation can be obtained.
This commit is contained in:
Stefan Kapusniak
2024-01-04 00:59:47 +00:00
committed by GitHub
parent 16c03e4b44
commit 8d9b5b3afa
24 changed files with 394 additions and 256 deletions

View File

@@ -159,6 +159,7 @@ class SharkifyStableDiffusionModel:
is_sdxl: bool = False,
stencils: list[str] = [],
use_lora: str = "",
lora_strength: float = 0.75,
use_quantize: str = None,
return_mlir: bool = False,
):
@@ -216,8 +217,14 @@ class SharkifyStableDiffusionModel:
self.is_upscaler = is_upscaler
self.stencils = [get_stencil_model_id(x) for x in stencils]
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.model_name = (
self.model_name
+ "_"
+ get_path_stem(use_lora)
+ f"@{int(lora_strength*100)}"
)
self.use_lora = use_lora
self.lora_strength = lora_strength
self.model_name = self.get_extended_name_for_all_model()
self.debug = debug
@@ -534,6 +541,7 @@ class SharkifyStableDiffusionModel:
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
lora_strength=self.lora_strength,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -542,7 +550,9 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
update_lora_weight(
self.unet, use_lora, "unet", lora_strength
)
self.in_channels = self.unet.config.in_channels
self.train(False)
@@ -818,6 +828,7 @@ class SharkifyStableDiffusionModel:
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
lora_strength=self.lora_strength,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -826,7 +837,9 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
update_lora_weight(
self.unet, use_lora, "unet", lora_strength
)
self.in_channels = self.unet.config.in_channels
self.train(False)
if (
@@ -1058,6 +1071,7 @@ class SharkifyStableDiffusionModel:
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
lora_strength=self.lora_strength,
):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
@@ -1067,7 +1081,10 @@ class SharkifyStableDiffusionModel:
)
if use_lora != "":
update_lora_weight(
self.text_encoder, use_lora, "text_encoder"
self.text_encoder,
use_lora,
"text_encoder",
lora_strength,
)
def forward(self, input):

View File

@@ -56,9 +56,12 @@ class Image2ImagePipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.vae_encode = None
def load_vae_encode(self):

View File

@@ -51,9 +51,12 @@ class InpaintPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.vae_encode = None
def load_vae_encode(self):

View File

@@ -52,9 +52,12 @@ class OutpaintPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.vae_encode = None
def load_vae_encode(self):

View File

@@ -64,10 +64,13 @@ class StencilPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
controlnet_names: list[str],
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.controlnet = [None] * len(controlnet_names)
self.controlnet_512 = [None] * len(controlnet_names)
self.controlnet_id = [str] * len(controlnet_names)

View File

@@ -49,9 +49,12 @@ class Text2ImagePipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
def prepare_latents(
self,

View File

@@ -51,10 +51,13 @@ class Text2ImageSDXLPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
is_fp32_vae: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.is_fp32_vae = is_fp32_vae
def prepare_latents(

View File

@@ -94,9 +94,12 @@ class UpscalerPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.low_res_scheduler = low_res_scheduler
self.status = SD_STATE_IDLE

View File

@@ -65,6 +65,7 @@ class StableDiffusionPipeline:
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
is_f32_vae: bool = False,
):
@@ -81,6 +82,7 @@ class StableDiffusionPipeline:
self.scheduler = scheduler
self.import_mlir = import_mlir
self.use_lora = use_lora
self.lora_strength = lora_strength
self.ondemand = ondemand
self.is_f32_vae = is_f32_vae
# TODO: Find a better workaround for fetching base_model_id early
@@ -647,6 +649,7 @@ class StableDiffusionPipeline:
stencils: list[str] = [],
# stencil_images: list[Image] = []
use_lora: str = "",
lora_strength: float = 0.75,
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
):
@@ -682,6 +685,7 @@ class StableDiffusionPipeline:
is_sdxl=is_sdxl,
stencils=stencils,
use_lora=use_lora,
lora_strength=lora_strength,
use_quantize=use_quantize,
)
@@ -692,12 +696,19 @@ class StableDiffusionPipeline:
sd_model,
import_mlir,
use_lora,
lora_strength,
ondemand,
)
if cls.__name__ == "StencilPipeline":
return cls(
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
scheduler,
sd_model,
import_mlir,
use_lora,
lora_strength,
ondemand,
stencils,
)
if cls.__name__ == "Text2ImageSDXLPipeline":
is_fp32_vae = True if "16" not in custom_vae else False
@@ -706,11 +717,14 @@ class StableDiffusionPipeline:
sd_model,
import_mlir,
use_lora,
lora_strength,
ondemand,
is_fp32_vae,
)
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
return cls(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
# #####################################################
# Implements text embeddings with weights from prompts

View File

@@ -435,6 +435,13 @@ p.add_argument(
"file (~3 MB).",
)
p.add_argument(
"--lora_strength",
type=float,
default=1.0,
help="Strength (alpha) scaling factor to use when applying LoRA weights",
)
p.add_argument(
"--use_quantize",
type=str,

View File

@@ -6,6 +6,7 @@ from PIL import PngImagePlugin
from PIL import Image
from datetime import datetime as dt
from csv import DictWriter
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from random import (
@@ -638,30 +639,51 @@ def convert_original_vae(vae_checkpoint):
return converted_vae_checkpoint
def processLoRA(model, use_lora, splitting_prefix):
@dataclass
class LoRAweight:
    """One LoRA layer's low-rank decomposition, gathered from a checkpoint.

    The effective weight delta applied to a model layer is derived from
    ``up @ down`` scaled by ``alpha`` (times the user's lora_strength).
    """

    # Low-rank factor matrices from the checkpoint's `up.weight` /
    # `down.weight` entries.
    up: torch.Tensor
    down: torch.Tensor
    # Optional `mid.weight` tensor; populated via state_dict.get(..., None),
    # so this may be None (mid handling is still a TODO in processLoRA).
    mid: torch.Tensor
    # Scale baked in from the checkpoint's `.alpha` entry divided by the
    # up-tensor's rank dimension; defaults to 1.0 when no alpha is present.
    alpha: float = 1.0
def processLoRA(model, use_lora, splitting_prefix, lora_strength):
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
else:
state_dict = torch.load(use_lora)
alpha = 0.75
visited = []
# directly update weight in model
process_unet = "te" not in splitting_prefix
# gather the weights from the LoRA in a more convenient form, assumes
# everything will have an up.weight. Unsure if this is a safe assumption.
weight_dict: dict[str, LoRAweight] = {}
for key in state_dict:
if ".alpha" in key or key in visited:
continue
if key.startswith(splitting_prefix) and key.endswith("up.weight"):
stem = key.split("up.weight")[0]
weight_key = stem.removesuffix(".lora_")
weight_key = weight_key.removesuffix("_lora_")
weight_key = weight_key.removesuffix(".lora_linear_layer.")
if weight_key not in weight_dict:
weight_dict[weight_key] = LoRAweight(
state_dict[f"{stem}up.weight"],
state_dict[f"{stem}down.weight"],
state_dict.get(f"{stem}mid.weight", None),
state_dict[f"{weight_key}.alpha"]
/ state_dict[f"{stem}up.weight"].shape[1]
if f"{weight_key}.alpha" in state_dict
else 1.0,
)
# Directly update weight in model
# Mostly adaptions of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
# and similar code in https://github.com/huggingface/diffusers/issues/3064
# TODO: handle mid weights (how do they even work?)
for key, lora_weight in weight_dict.items():
curr_layer = model
if ("text" not in key and process_unet) or (
"text" in key and not process_unet
):
layer_infos = (
key.split(".")[0].split(splitting_prefix)[-1].split("_")
)
else:
continue
layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")
# find the target layer
temp_name = layer_infos.pop(0)
@@ -678,46 +700,46 @@ def processLoRA(model, use_lora, splitting_prefix):
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = (
state_dict[pair_keys[0]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
weight = curr_layer.weight.data
scale = lora_weight.alpha * lora_strength
if len(weight.size()) == 2:
if len(lora_weight.up.shape) == 4:
weight_up = (
lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
)
weight_down = (
lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
)
change = (
torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
)
else:
change = torch.mm(lora_weight.up, lora_weight.down)
elif lora_weight.down.size()[2:4] == (1, 1):
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = (
state_dict[pair_keys[1]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
)
curr_layer.weight.data += alpha * torch.mm(
weight_up, weight_down
).unsqueeze(2).unsqueeze(3)
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
# update visited list
for item in pair_keys:
visited.append(item)
change = torch.nn.functional.conv2d(
lora_weight.down.permute(1, 0, 2, 3),
lora_weight.up,
).permute(1, 0, 2, 3)
curr_layer.weight.data += change * scale
return model
def update_lora_weight_for_unet(unet, use_lora):
def update_lora_weight_for_unet(unet, use_lora, lora_strength):
extensions = [".bin", ".safetensors", ".pt"]
if not any([extension in use_lora for extension in extensions]):
# If not, we assume it is a HF ID with standalone LoRA weights.
unet.load_attn_procs(use_lora)
print(
f"updated unet weights via diffusers load_attn_procs from LoRA: {use_lora}"
)
return unet
main_file_name = get_path_stem(use_lora)
@@ -733,16 +755,21 @@ def update_lora_weight_for_unet(unet, use_lora):
try:
dir_name = os.path.dirname(use_lora)
unet.load_attn_procs(dir_name, weight_name=main_file_name)
print(
f"updated unet weights via diffusers load_attn_procs from LoRA: {use_lora}"
)
return unet
except:
return processLoRA(unet, use_lora, "lora_unet_")
print(f"updated unet weights manually from LoRA: {use_lora}")
return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
def update_lora_weight(model, use_lora, model_name):
def update_lora_weight(model, use_lora, model_name, lora_strength):
if "unet" in model_name:
return update_lora_weight_for_unet(model, use_lora)
return update_lora_weight_for_unet(model, use_lora, lora_strength)
try:
return processLoRA(model, use_lora, "lora_te_")
print(f"updating CLIP weights from LoRA: {use_lora}")
return processLoRA(model, use_lora, "lora_te_", lora_strength)
except:
return None
@@ -898,7 +925,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
img_lora = None
if args.use_lora:
img_lora = Path(os.path.basename(args.use_lora)).stem
img_lora = f"{Path(os.path.basename(args.use_lora)).stem}:{args.lora_strength}"
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")

View File

@@ -12,7 +12,6 @@ from apps.stable_diffusion.web.api.utils import (
decode_base64_to_image,
get_model_from_request,
get_scheduler_from_request,
get_lora_params,
get_device,
GenerationInputData,
GenerationResponseData,
@@ -180,7 +179,6 @@ def txt2img_api(InputData: Txt2ImgInputData):
scheduler = get_scheduler_from_request(
InputData, "txt2img_hires" if InputData.enable_hr else "txt2img"
)
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
print(
f"Prompt: {InputData.prompt}, "
@@ -208,8 +206,8 @@ def txt2img_api(InputData: Txt2ImgInputData):
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
use_hiresfix=InputData.enable_hr,
@@ -270,7 +268,6 @@ def img2img_api(
fallback_model="stabilityai/stable-diffusion-2-1-base",
)
scheduler = get_scheduler_from_request(InputData, "img2img")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
mask_image = (
@@ -308,8 +305,8 @@ def img2img_api(
use_stencil=InputData.use_stencil,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
resample_type=frozen_args.resample_type,
@@ -358,7 +355,6 @@ def inpaint_api(
fallback_model="stabilityai/stable-diffusion-2-inpainting",
)
scheduler = get_scheduler_from_request(InputData, "inpaint")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.image)
mask = decode_base64_to_image(InputData.mask)
@@ -393,8 +389,8 @@ def inpaint_api(
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
@@ -448,7 +444,6 @@ def outpaint_api(
fallback_model="stabilityai/stable-diffusion-2-inpainting",
)
scheduler = get_scheduler_from_request(InputData, "outpaint")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
@@ -484,8 +479,8 @@ def outpaint_api(
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
@@ -531,7 +526,6 @@ def upscaler_api(
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
)
scheduler = get_scheduler_from_request(InputData, "upscaler")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
@@ -563,8 +557,8 @@ def upscaler_api(
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)

View File

@@ -191,17 +191,6 @@ def get_scheduler_from_request(
)
def get_lora_params(use_lora: str):
# TODO: since the inference functions in the webui, which we are
# still calling into for the api, jam these back together again before
# handing them off to the pipeline, we should remove this nonsense
# and unify their selection in the UI and command line args proper
if use_lora in get_custom_model_files("lora"):
return (use_lora, "")
return ("None", use_lora)
def get_device(device_str: str):
# first substring match in the list available devices, with first
# device when none are matched

View File

@@ -55,3 +55,10 @@ def lora_changed(lora_file):
return [
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]
def lora_strength_changed(strength):
    """Visually flag the LoRA strength UI input when it exceeds 1.0.

    Returns a gr.Number update setting the "value-out-of-range" CSS class
    (strengths above 1.0 overdrive the LoRA alpha, which the UI permits
    but highlights), or an update clearing the class otherwise.
    """
    if strength > 1.0:
        return gr.Number(elem_classes="value-out-of-range")
    else:
        return gr.Number(elem_classes="")

View File

@@ -244,6 +244,11 @@ footer {
padding-right: 8px;
}
/* number input value is out of range */
.value-out-of-range input[type="number"] {
color: red !important;
}
/* reduced animation load when generating */
.generating {
animation-play-state: paused !important;

View File

@@ -21,7 +21,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
@@ -74,7 +77,7 @@ def img2img_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
@@ -141,9 +144,8 @@ def img2img_inf(
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -176,6 +178,7 @@ def img2img_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=stencils,
ondemand=ondemand,
)
@@ -228,6 +231,7 @@ def img2img_inf(
stencils=stencils,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -249,6 +253,7 @@ def img2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -806,28 +811,25 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
i2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
lora_weights = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=i2i_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -1013,7 +1015,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
resample_type,
@@ -1054,3 +1056,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -21,7 +21,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -109,7 +112,7 @@ def inpaint_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: int,
):
@@ -150,9 +153,8 @@ def inpaint_inf(
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -171,6 +173,7 @@ def inpaint_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -215,6 +218,7 @@ def inpaint_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -350,28 +354,25 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
inpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
inpaint_lora_info = f"LoRA Path: {inpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=inpaint_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -558,7 +559,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
],
@@ -622,3 +623,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -238,9 +238,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
max_length,
training_images_dir,
output_loc,
get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
),
get_custom_vae_or_lora_weights(lora_weights, "lora"),
],
outputs=[std_output],
show_progress="minimal" if args.progress_bar else "none",

View File

@@ -4,7 +4,10 @@ import time
import gradio as gr
from PIL import Image
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -60,7 +63,7 @@ def outpaint_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
):
@@ -100,9 +103,8 @@ def outpaint_inf(
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -121,6 +123,7 @@ def outpaint_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -163,6 +166,7 @@ def outpaint_inf(
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -296,28 +300,25 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
outpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
outpaint_lora_info = f"LoRA Path: {outpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=outpaint_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -527,7 +528,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
],
@@ -556,3 +557,11 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -15,7 +15,10 @@ from apps.stable_diffusion.web.ui.utils import (
cancel_sd,
set_model_default_configs,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
@@ -59,7 +62,7 @@ def txt2img_sdxl_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
):
@@ -105,9 +108,8 @@ def txt2img_sdxl_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
@@ -123,6 +125,7 @@ def txt2img_sdxl_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=None,
ondemand=ondemand,
)
@@ -171,6 +174,7 @@ def txt2img_sdxl_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
use_quantize=args.use_quantize,
ondemand=global_obj.get_cfg_obj().ondemand,
)
@@ -316,28 +320,25 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
t2i_sdxl_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
t2i_sdxl_lora_info = f"LoRA Path: {t2i_sdxl_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=t2i_sdxl_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -539,7 +540,7 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
],
@@ -609,7 +610,6 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
height,
txt2img_sdxl_custom_model,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
@@ -624,7 +624,6 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
height,
txt2img_sdxl_custom_model,
lora_weights,
lora_hf_id,
custom_vae,
],
)
@@ -651,3 +650,11 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -18,7 +18,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
@@ -44,7 +47,7 @@ all_gradio_labels = [
"prompt",
"negative_prompt",
"lora_weights",
"lora_hf_id",
"lora_strength",
"scheduler",
"save_metadata_to_png",
"save_metadata_to_json",
@@ -91,7 +94,7 @@ def txt2img_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
use_hiresfix: bool,
@@ -138,9 +141,8 @@ def txt2img_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
@@ -156,6 +158,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -207,6 +210,7 @@ def txt2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -256,6 +260,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -288,6 +293,7 @@ def txt2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -385,7 +391,7 @@ def load_settings():
loaded_settings.get("prompt", args.prompts[0]),
loaded_settings.get("negative_prompt", args.negative_prompts[0]),
loaded_settings.get("lora_weights", "None"),
loaded_settings.get("lora_hf_id", ""),
loaded_settings.get("lora_strength", args.lora_strength),
loaded_settings.get("scheduler", args.scheduler),
loaded_settings.get(
"save_metadata_to_png", args.write_metadata_to_png
@@ -495,28 +501,25 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
t2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
t2i_lora_info = f"LoRA Path: {t2i_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=t2i_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value=default_settings.get("lora_weights"),
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value=default_settings.get("lora_hf_id"),
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=default_settings.get("lora_strength"),
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -736,7 +739,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
prompt,
negative_prompt,
lora_weights,
lora_hf_id,
lora_strength,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
@@ -769,7 +772,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
prompt,
negative_prompt,
lora_weights,
lora_hf_id,
lora_strength,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
@@ -813,7 +816,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
use_hiresfix,
@@ -856,7 +859,6 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
height,
txt2img_custom_model,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
@@ -871,7 +873,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
height,
txt2img_custom_model,
lora_weights,
lora_hf_id,
lora_strength,
custom_vae,
],
)
@@ -902,3 +904,11 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -13,7 +13,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
@@ -53,7 +56,7 @@ def upscaler_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
):
@@ -100,9 +103,8 @@ def upscaler_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
@@ -120,6 +122,7 @@ def upscaler_inf(
args.width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -159,6 +162,7 @@ def upscaler_inf(
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -318,28 +322,25 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
upscaler_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=upscaler_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -523,7 +524,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
],
@@ -552,3 +553,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -33,6 +33,7 @@ class Config:
width: int
device: str
use_lora: str
lora_strength: float
stencils: list[str]
ondemand: str # should this be expecting a bool instead?
@@ -180,14 +181,16 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
return sorted(ckpt_files, key=str.casefold)
def get_custom_vae_or_lora_weights(weights, model):
    """Resolve a VAE/LoRA UI selection to a usable weights reference.

    Args:
        weights: The UI selection; either "None", the name of a file under
            the custom model directory for ``model``, or (presumably) a
            HuggingFace model id entered as a custom value.
        model: Custom-model category/subdirectory name, e.g. "lora" or "vae".

    Returns:
        "" when nothing is selected; the resolved local path when
        ``weights`` names an existing file in the custom model directory;
        otherwise ``weights`` unchanged (treated as a HuggingFace model id).
    """
    if weights == "None":
        return ""
    # Prefer a local file in the custom models directory; if it does not
    # exist on disk, pass the selection through as a HuggingFace model id.
    custom_weights = get_custom_model_pathfile(str(weights), model)
    if os.path.isfile(custom_weights):
        return custom_weights
    return weights

View File

@@ -122,20 +122,26 @@ def find_vae_from_png_metadata(
def find_lora_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
lora_hf_id = ""
) -> tuple[str, float]:
lora_custom = ""
lora_strength = 1.0
if key in metadata:
lora_file = metadata[key]
split_metadata = metadata[key].split(":")
lora_file = split_metadata[0]
if len(split_metadata) == 2:
try:
lora_strength = float(split_metadata[1])
except ValueError:
pass
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
# If nothing had matched, check vendor/hf_model_id
if not lora_custom and lora_file.count("/"):
lora_hf_id = lora_file
lora_custom = lora_file
# LoRA input is optional, should not print or throw an error if missing
return lora_custom, lora_hf_id
return lora_custom, lora_strength
def import_png_metadata(
@@ -150,7 +156,6 @@ def import_png_metadata(
height,
custom_model,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
@@ -160,9 +165,10 @@ def import_png_metadata(
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
(
custom_lora,
custom_lora_strength,
) = find_lora_from_png_metadata("LoRA", metadata)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
negative_prompt = metadata["Negative prompt"]
@@ -177,12 +183,8 @@ def import_png_metadata(
elif "Model" in metadata and png_hf_model_id:
custom_model = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
if "LoRA" in metadata and not custom_lora:
custom_lora = "None"
hf_lora_id = lora_hf_model_id
if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
@@ -215,6 +217,6 @@ def import_png_metadata(
height,
custom_model,
custom_lora,
hf_lora_id,
custom_lora_strength,
custom_vae,
)