mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
SD/UI: Merge Lora Selection Boxes, Add LoRA Strength (#2052)
* Merges LoRA selection in the UI into a single selection, rather than one for LoRAs under ./models and another for Hugging Face Id * Add LoRA strength to UI and pipeline parameters. * Add a `--lora_strength` command line argument. * Bake LoRA strength into .vmfb naming when a LoRA is specified. * Use LoRA embedded alpha values and (up tensor dimension * LoRA strength) for final alpha when applying LoRA weights rather than a hardcoded value of 0.75 * Adds additional cases to the LoRA weight application that are present for weight application in the Kohya scripts. * Include lora strength when reading and writing png metadata. * Allow lora_strength to be set above 1.0 in the UI, so similar effects to the prior (overdriven alpha) implementation can be obtained.
This commit is contained in:
@@ -159,6 +159,7 @@ class SharkifyStableDiffusionModel:
|
||||
is_sdxl: bool = False,
|
||||
stencils: list[str] = [],
|
||||
use_lora: str = "",
|
||||
lora_strength: float = 0.75,
|
||||
use_quantize: str = None,
|
||||
return_mlir: bool = False,
|
||||
):
|
||||
@@ -216,8 +217,14 @@ class SharkifyStableDiffusionModel:
|
||||
self.is_upscaler = is_upscaler
|
||||
self.stencils = [get_stencil_model_id(x) for x in stencils]
|
||||
if use_lora != "":
|
||||
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
|
||||
self.model_name = (
|
||||
self.model_name
|
||||
+ "_"
|
||||
+ get_path_stem(use_lora)
|
||||
+ f"@{int(lora_strength*100)}"
|
||||
)
|
||||
self.use_lora = use_lora
|
||||
self.lora_strength = lora_strength
|
||||
|
||||
self.model_name = self.get_extended_name_for_all_model()
|
||||
self.debug = debug
|
||||
@@ -534,6 +541,7 @@ class SharkifyStableDiffusionModel:
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
use_lora=self.use_lora,
|
||||
lora_strength=self.lora_strength,
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
@@ -542,7 +550,9 @@ class SharkifyStableDiffusionModel:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
update_lora_weight(
|
||||
self.unet, use_lora, "unet", lora_strength
|
||||
)
|
||||
self.in_channels = self.unet.config.in_channels
|
||||
self.train(False)
|
||||
|
||||
@@ -818,6 +828,7 @@ class SharkifyStableDiffusionModel:
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
use_lora=self.use_lora,
|
||||
lora_strength=self.lora_strength,
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
@@ -826,7 +837,9 @@ class SharkifyStableDiffusionModel:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
update_lora_weight(
|
||||
self.unet, use_lora, "unet", lora_strength
|
||||
)
|
||||
self.in_channels = self.unet.config.in_channels
|
||||
self.train(False)
|
||||
if (
|
||||
@@ -1058,6 +1071,7 @@ class SharkifyStableDiffusionModel:
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
use_lora=self.use_lora,
|
||||
lora_strength=self.lora_strength,
|
||||
):
|
||||
super().__init__()
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
@@ -1067,7 +1081,10 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(
|
||||
self.text_encoder, use_lora, "text_encoder"
|
||||
self.text_encoder,
|
||||
use_lora,
|
||||
"text_encoder",
|
||||
lora_strength,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
@@ -56,9 +56,12 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
|
||||
|
||||
@@ -51,9 +51,12 @@ class InpaintPipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
|
||||
|
||||
@@ -52,9 +52,12 @@ class OutpaintPipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
|
||||
|
||||
@@ -64,10 +64,13 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
controlnet_names: list[str],
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
self.controlnet = [None] * len(controlnet_names)
|
||||
self.controlnet_512 = [None] * len(controlnet_names)
|
||||
self.controlnet_id = [str] * len(controlnet_names)
|
||||
|
||||
@@ -49,9 +49,12 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
|
||||
@@ -51,10 +51,13 @@ class Text2ImageSDXLPipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
is_fp32_vae: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
self.is_fp32_vae = is_fp32_vae
|
||||
|
||||
def prepare_latents(
|
||||
|
||||
@@ -94,9 +94,12 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
self.low_res_scheduler = low_res_scheduler
|
||||
self.status = SD_STATE_IDLE
|
||||
|
||||
|
||||
@@ -65,6 +65,7 @@ class StableDiffusionPipeline:
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
is_f32_vae: bool = False,
|
||||
):
|
||||
@@ -81,6 +82,7 @@ class StableDiffusionPipeline:
|
||||
self.scheduler = scheduler
|
||||
self.import_mlir = import_mlir
|
||||
self.use_lora = use_lora
|
||||
self.lora_strength = lora_strength
|
||||
self.ondemand = ondemand
|
||||
self.is_f32_vae = is_f32_vae
|
||||
# TODO: Find a better workaround for fetching base_model_id early
|
||||
@@ -647,6 +649,7 @@ class StableDiffusionPipeline:
|
||||
stencils: list[str] = [],
|
||||
# stencil_images: list[Image] = []
|
||||
use_lora: str = "",
|
||||
lora_strength: float = 0.75,
|
||||
ddpm_scheduler: DDPMScheduler = None,
|
||||
use_quantize=None,
|
||||
):
|
||||
@@ -682,6 +685,7 @@ class StableDiffusionPipeline:
|
||||
is_sdxl=is_sdxl,
|
||||
stencils=stencils,
|
||||
use_lora=use_lora,
|
||||
lora_strength=lora_strength,
|
||||
use_quantize=use_quantize,
|
||||
)
|
||||
|
||||
@@ -692,12 +696,19 @@ class StableDiffusionPipeline:
|
||||
sd_model,
|
||||
import_mlir,
|
||||
use_lora,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
)
|
||||
|
||||
if cls.__name__ == "StencilPipeline":
|
||||
return cls(
|
||||
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
|
||||
scheduler,
|
||||
sd_model,
|
||||
import_mlir,
|
||||
use_lora,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
stencils,
|
||||
)
|
||||
if cls.__name__ == "Text2ImageSDXLPipeline":
|
||||
is_fp32_vae = True if "16" not in custom_vae else False
|
||||
@@ -706,11 +717,14 @@ class StableDiffusionPipeline:
|
||||
sd_model,
|
||||
import_mlir,
|
||||
use_lora,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
is_fp32_vae,
|
||||
)
|
||||
|
||||
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
return cls(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
|
||||
# #####################################################
|
||||
# Implements text embeddings with weights from prompts
|
||||
|
||||
@@ -435,6 +435,13 @@ p.add_argument(
|
||||
"file (~3 MB).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--lora_strength",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Strength (alpha) scaling factor to use when applying LoRA weights",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_quantize",
|
||||
type=str,
|
||||
|
||||
@@ -6,6 +6,7 @@ from PIL import PngImagePlugin
|
||||
from PIL import Image
|
||||
from datetime import datetime as dt
|
||||
from csv import DictWriter
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from random import (
|
||||
@@ -638,30 +639,51 @@ def convert_original_vae(vae_checkpoint):
|
||||
return converted_vae_checkpoint
|
||||
|
||||
|
||||
def processLoRA(model, use_lora, splitting_prefix):
|
||||
@dataclass
|
||||
class LoRAweight:
|
||||
up: torch.tensor
|
||||
down: torch.tensor
|
||||
mid: torch.tensor
|
||||
alpha: torch.float32 = 1.0
|
||||
|
||||
|
||||
def processLoRA(model, use_lora, splitting_prefix, lora_strength):
|
||||
state_dict = ""
|
||||
if ".safetensors" in use_lora:
|
||||
state_dict = load_file(use_lora)
|
||||
else:
|
||||
state_dict = torch.load(use_lora)
|
||||
alpha = 0.75
|
||||
visited = []
|
||||
|
||||
# directly update weight in model
|
||||
process_unet = "te" not in splitting_prefix
|
||||
# gather the weights from the LoRA in a more convenient form, assumes
|
||||
# everything will have an up.weight. Unsure if this is a safe assumption.
|
||||
weight_dict: dict[str, LoRAweight] = {}
|
||||
for key in state_dict:
|
||||
if ".alpha" in key or key in visited:
|
||||
continue
|
||||
if key.startswith(splitting_prefix) and key.endswith("up.weight"):
|
||||
stem = key.split("up.weight")[0]
|
||||
weight_key = stem.removesuffix(".lora_")
|
||||
weight_key = weight_key.removesuffix("_lora_")
|
||||
weight_key = weight_key.removesuffix(".lora_linear_layer.")
|
||||
|
||||
if weight_key not in weight_dict:
|
||||
weight_dict[weight_key] = LoRAweight(
|
||||
state_dict[f"{stem}up.weight"],
|
||||
state_dict[f"{stem}down.weight"],
|
||||
state_dict.get(f"{stem}mid.weight", None),
|
||||
state_dict[f"{weight_key}.alpha"]
|
||||
/ state_dict[f"{stem}up.weight"].shape[1]
|
||||
if f"{weight_key}.alpha" in state_dict
|
||||
else 1.0,
|
||||
)
|
||||
|
||||
# Directly update weight in model
|
||||
|
||||
# Mostly adaptions of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
|
||||
# and similar code in https://github.com/huggingface/diffusers/issues/3064
|
||||
|
||||
# TODO: handle mid weights (how do they even work?)
|
||||
for key, lora_weight in weight_dict.items():
|
||||
curr_layer = model
|
||||
if ("text" not in key and process_unet) or (
|
||||
"text" in key and not process_unet
|
||||
):
|
||||
layer_infos = (
|
||||
key.split(".")[0].split(splitting_prefix)[-1].split("_")
|
||||
)
|
||||
else:
|
||||
continue
|
||||
layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")
|
||||
|
||||
# find the target layer
|
||||
temp_name = layer_infos.pop(0)
|
||||
@@ -678,46 +700,46 @@ def processLoRA(model, use_lora, splitting_prefix):
|
||||
else:
|
||||
temp_name = layer_infos.pop(0)
|
||||
|
||||
pair_keys = []
|
||||
if "lora_down" in key:
|
||||
pair_keys.append(key.replace("lora_down", "lora_up"))
|
||||
pair_keys.append(key)
|
||||
else:
|
||||
pair_keys.append(key)
|
||||
pair_keys.append(key.replace("lora_up", "lora_down"))
|
||||
|
||||
# update weight
|
||||
if len(state_dict[pair_keys[0]].shape) == 4:
|
||||
weight_up = (
|
||||
state_dict[pair_keys[0]]
|
||||
.squeeze(3)
|
||||
.squeeze(2)
|
||||
.to(torch.float32)
|
||||
)
|
||||
weight = curr_layer.weight.data
|
||||
scale = lora_weight.alpha * lora_strength
|
||||
if len(weight.size()) == 2:
|
||||
if len(lora_weight.up.shape) == 4:
|
||||
weight_up = (
|
||||
lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
)
|
||||
weight_down = (
|
||||
lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
)
|
||||
change = (
|
||||
torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
)
|
||||
else:
|
||||
change = torch.mm(lora_weight.up, lora_weight.down)
|
||||
elif lora_weight.down.size()[2:4] == (1, 1):
|
||||
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = (
|
||||
state_dict[pair_keys[1]]
|
||||
.squeeze(3)
|
||||
.squeeze(2)
|
||||
.to(torch.float32)
|
||||
lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
)
|
||||
curr_layer.weight.data += alpha * torch.mm(
|
||||
weight_up, weight_down
|
||||
).unsqueeze(2).unsqueeze(3)
|
||||
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
||||
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
||||
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
|
||||
# update visited list
|
||||
for item in pair_keys:
|
||||
visited.append(item)
|
||||
change = torch.nn.functional.conv2d(
|
||||
lora_weight.down.permute(1, 0, 2, 3),
|
||||
lora_weight.up,
|
||||
).permute(1, 0, 2, 3)
|
||||
|
||||
curr_layer.weight.data += change * scale
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def update_lora_weight_for_unet(unet, use_lora):
|
||||
def update_lora_weight_for_unet(unet, use_lora, lora_strength):
|
||||
extensions = [".bin", ".safetensors", ".pt"]
|
||||
if not any([extension in use_lora for extension in extensions]):
|
||||
# We assume if it is a HF ID with standalone LoRA weights.
|
||||
unet.load_attn_procs(use_lora)
|
||||
print(
|
||||
f"updated unet weights via diffusers load_attn_procs from LoRA: {use_lora}"
|
||||
)
|
||||
return unet
|
||||
|
||||
main_file_name = get_path_stem(use_lora)
|
||||
@@ -733,16 +755,21 @@ def update_lora_weight_for_unet(unet, use_lora):
|
||||
try:
|
||||
dir_name = os.path.dirname(use_lora)
|
||||
unet.load_attn_procs(dir_name, weight_name=main_file_name)
|
||||
print(
|
||||
f"updated unet weights via diffusers load_attn_procs from LoRA: {use_lora}"
|
||||
)
|
||||
return unet
|
||||
except:
|
||||
return processLoRA(unet, use_lora, "lora_unet_")
|
||||
print(f"updated unet weights manually from LoRA: {use_lora}")
|
||||
return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
|
||||
|
||||
|
||||
def update_lora_weight(model, use_lora, model_name):
|
||||
def update_lora_weight(model, use_lora, model_name, lora_strength):
|
||||
if "unet" in model_name:
|
||||
return update_lora_weight_for_unet(model, use_lora)
|
||||
return update_lora_weight_for_unet(model, use_lora, lora_strength)
|
||||
try:
|
||||
return processLoRA(model, use_lora, "lora_te_")
|
||||
print(f"updating CLIP weights from LoRA: {use_lora}")
|
||||
return processLoRA(model, use_lora, "lora_te_", lora_strength)
|
||||
except:
|
||||
return None
|
||||
|
||||
@@ -898,7 +925,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
|
||||
img_lora = None
|
||||
if args.use_lora:
|
||||
img_lora = Path(os.path.basename(args.use_lora)).stem
|
||||
img_lora = f"{Path(os.path.basename(args.use_lora)).stem}:{args.lora_strength}"
|
||||
|
||||
if args.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
|
||||
@@ -12,7 +12,6 @@ from apps.stable_diffusion.web.api.utils import (
|
||||
decode_base64_to_image,
|
||||
get_model_from_request,
|
||||
get_scheduler_from_request,
|
||||
get_lora_params,
|
||||
get_device,
|
||||
GenerationInputData,
|
||||
GenerationResponseData,
|
||||
@@ -180,7 +179,6 @@ def txt2img_api(InputData: Txt2ImgInputData):
|
||||
scheduler = get_scheduler_from_request(
|
||||
InputData, "txt2img_hires" if InputData.enable_hr else "txt2img"
|
||||
)
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
print(
|
||||
f"Prompt: {InputData.prompt}, "
|
||||
@@ -208,8 +206,8 @@ def txt2img_api(InputData: Txt2ImgInputData):
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
lora_weights=frozen_args.use_lora,
|
||||
lora_strength=frozen_args.lora_strength,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
use_hiresfix=InputData.enable_hr,
|
||||
@@ -270,7 +268,6 @@ def img2img_api(
|
||||
fallback_model="stabilityai/stable-diffusion-2-1-base",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "img2img")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.init_images[0])
|
||||
mask_image = (
|
||||
@@ -308,8 +305,8 @@ def img2img_api(
|
||||
use_stencil=InputData.use_stencil,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
lora_weights=frozen_args.use_lora,
|
||||
lora_strength=frozen_args.lora_strength,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
resample_type=frozen_args.resample_type,
|
||||
@@ -358,7 +355,6 @@ def inpaint_api(
|
||||
fallback_model="stabilityai/stable-diffusion-2-inpainting",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "inpaint")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.image)
|
||||
mask = decode_base64_to_image(InputData.mask)
|
||||
@@ -393,8 +389,8 @@ def inpaint_api(
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
lora_weights=frozen_args.use_lora,
|
||||
lora_strength=frozen_args.lora_strength,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
@@ -448,7 +444,6 @@ def outpaint_api(
|
||||
fallback_model="stabilityai/stable-diffusion-2-inpainting",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "outpaint")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.init_images[0])
|
||||
|
||||
@@ -484,8 +479,8 @@ def outpaint_api(
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
lora_weights=frozen_args.use_lora,
|
||||
lora_strength=frozen_args.lora_strength,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
@@ -531,7 +526,6 @@ def upscaler_api(
|
||||
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "upscaler")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.init_images[0])
|
||||
|
||||
@@ -563,8 +557,8 @@ def upscaler_api(
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
lora_weights=frozen_args.use_lora,
|
||||
lora_strength=frozen_args.lora_strength,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
@@ -191,17 +191,6 @@ def get_scheduler_from_request(
|
||||
)
|
||||
|
||||
|
||||
def get_lora_params(use_lora: str):
|
||||
# TODO: since the inference functions in the webui, which we are
|
||||
# still calling into for the api, jam these back together again before
|
||||
# handing them off to the pipeline, we should remove this nonsense
|
||||
# and unify their selection in the UI and command line args proper
|
||||
if use_lora in get_custom_model_files("lora"):
|
||||
return (use_lora, "")
|
||||
|
||||
return ("None", use_lora)
|
||||
|
||||
|
||||
def get_device(device_str: str):
|
||||
# first substring match in the list available devices, with first
|
||||
# device when none are matched
|
||||
|
||||
@@ -55,3 +55,10 @@ def lora_changed(lora_file):
|
||||
return [
|
||||
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
|
||||
]
|
||||
|
||||
|
||||
def lora_strength_changed(strength):
|
||||
if strength > 1.0:
|
||||
return gr.Number(elem_classes="value-out-of-range")
|
||||
else:
|
||||
return gr.Number(elem_classes="")
|
||||
|
||||
@@ -244,6 +244,11 @@ footer {
|
||||
padding-right: 8px;
|
||||
}
|
||||
|
||||
/* number input value is out of range */
|
||||
.value-out-of-range input[type="number"] {
|
||||
color: red !important;
|
||||
}
|
||||
|
||||
/* reduced animation load when generating */
|
||||
.generating {
|
||||
animation-play-state: paused !important;
|
||||
|
||||
@@ -21,7 +21,10 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import (
|
||||
lora_changed,
|
||||
lora_strength_changed,
|
||||
)
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Image2ImagePipeline,
|
||||
@@ -74,7 +77,7 @@ def img2img_inf(
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
@@ -141,9 +144,8 @@ def img2img_inf(
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
|
||||
args.lora_strength = lora_strength
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
@@ -176,6 +178,7 @@ def img2img_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=stencils,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
@@ -228,6 +231,7 @@ def img2img_inf(
|
||||
stencils=stencils,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -249,6 +253,7 @@ def img2img_inf(
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -806,28 +811,25 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
i2i_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=i2i_lora_info,
|
||||
label=f"LoRA Weights",
|
||||
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
scale=3,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
lora_strength = gr.Number(
|
||||
label="LoRA Strength",
|
||||
info="Will be baked into the .vmfb",
|
||||
step=0.01,
|
||||
# number is checked on change so to allow 0.n values
|
||||
# we have to allow 0 or you can't type 0.n in
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=args.lora_strength,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
@@ -1013,7 +1015,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
@@ -1054,3 +1056,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
lora_strength.change(
|
||||
fn=lora_strength_changed,
|
||||
inputs=lora_strength,
|
||||
outputs=lora_strength,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -21,7 +21,10 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_paint_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import (
|
||||
lora_changed,
|
||||
lora_strength_changed,
|
||||
)
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
InpaintPipeline,
|
||||
@@ -109,7 +112,7 @@ def inpaint_inf(
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: int,
|
||||
):
|
||||
@@ -150,9 +153,8 @@ def inpaint_inf(
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
|
||||
args.lora_strength = lora_strength
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
@@ -171,6 +173,7 @@ def inpaint_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
@@ -215,6 +218,7 @@ def inpaint_inf(
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -350,28 +354,25 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
)
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
inpaint_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
inpaint_lora_info = f"LoRA Path: {inpaint_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=inpaint_lora_info,
|
||||
label=f"LoRA Weights",
|
||||
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
scale=3,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
lora_strength = gr.Number(
|
||||
label="LoRA Strength",
|
||||
info="Will be baked into the .vmfb",
|
||||
step=0.01,
|
||||
# number is checked on change so to allow 0.n values
|
||||
# we have to allow 0 or you can't type 0.n in
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=args.lora_strength,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
@@ -558,7 +559,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
@@ -622,3 +623,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
lora_strength.change(
|
||||
fn=lora_strength_changed,
|
||||
inputs=lora_strength,
|
||||
outputs=lora_strength,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -238,9 +238,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
max_length,
|
||||
training_images_dir,
|
||||
output_loc,
|
||||
get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
),
|
||||
get_custom_vae_or_lora_weights(lora_weights, "lora"),
|
||||
],
|
||||
outputs=[std_output],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
|
||||
@@ -4,7 +4,10 @@ import time
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import (
|
||||
lora_changed,
|
||||
lora_strength_changed,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -60,7 +63,7 @@ def outpaint_inf(
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
@@ -100,9 +103,8 @@ def outpaint_inf(
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
|
||||
args.lora_strength = lora_strength
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
@@ -121,6 +123,7 @@ def outpaint_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
@@ -163,6 +166,7 @@ def outpaint_inf(
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -296,28 +300,25 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
)
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
outpaint_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
outpaint_lora_info = f"LoRA Path: {outpaint_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=outpaint_lora_info,
|
||||
label=f"LoRA Weights",
|
||||
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
scale=3,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
lora_strength = gr.Number(
|
||||
label="LoRA Strength",
|
||||
info="Will be baked into the .vmfb",
|
||||
step=0.01,
|
||||
# number is checked on change so to allow 0.n values
|
||||
# we have to allow 0 or you can't type 0.n in
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=args.lora_strength,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
@@ -527,7 +528,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
@@ -556,3 +557,11 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
lora_strength.change(
|
||||
fn=lora_strength_changed,
|
||||
inputs=lora_strength,
|
||||
outputs=lora_strength,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,10 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
cancel_sd,
|
||||
set_model_default_configs,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import (
|
||||
lora_changed,
|
||||
lora_strength_changed,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
@@ -59,7 +62,7 @@ def txt2img_sdxl_inf(
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
@@ -105,9 +108,8 @@ def txt2img_sdxl_inf(
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
|
||||
args.lora_strength = lora_strength
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
@@ -123,6 +125,7 @@ def txt2img_sdxl_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=None,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
@@ -171,6 +174,7 @@ def txt2img_sdxl_inf(
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=global_obj.get_cfg_obj().ondemand,
|
||||
)
|
||||
@@ -316,28 +320,25 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
|
||||
)
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
t2i_sdxl_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
t2i_sdxl_lora_info = f"LoRA Path: {t2i_sdxl_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=t2i_sdxl_lora_info,
|
||||
label=f"LoRA Weights",
|
||||
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
scale=3,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
lora_strength = gr.Number(
|
||||
label="LoRA Strength",
|
||||
info="Will be baked into the .vmfb",
|
||||
step=0.01,
|
||||
# number is checked on change so to allow 0.n values
|
||||
# we have to allow 0 or you can't type 0.n in
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=args.lora_strength,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
@@ -539,7 +540,7 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
@@ -609,7 +610,6 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
|
||||
height,
|
||||
txt2img_sdxl_custom_model,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
outputs=[
|
||||
@@ -624,7 +624,6 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
|
||||
height,
|
||||
txt2img_sdxl_custom_model,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
)
|
||||
@@ -651,3 +650,11 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
lora_strength.change(
|
||||
fn=lora_strength_changed,
|
||||
inputs=lora_strength,
|
||||
outputs=lora_strength,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,10 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import (
|
||||
lora_changed,
|
||||
lora_strength_changed,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
@@ -44,7 +47,7 @@ all_gradio_labels = [
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
"lora_weights",
|
||||
"lora_hf_id",
|
||||
"lora_strength",
|
||||
"scheduler",
|
||||
"save_metadata_to_png",
|
||||
"save_metadata_to_json",
|
||||
@@ -91,7 +94,7 @@ def txt2img_inf(
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
use_hiresfix: bool,
|
||||
@@ -138,9 +141,8 @@ def txt2img_inf(
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
|
||||
args.lora_strength = lora_strength
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
@@ -156,6 +158,7 @@ def txt2img_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
@@ -207,6 +210,7 @@ def txt2img_inf(
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -256,6 +260,7 @@ def txt2img_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
@@ -288,6 +293,7 @@ def txt2img_inf(
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -385,7 +391,7 @@ def load_settings():
|
||||
loaded_settings.get("prompt", args.prompts[0]),
|
||||
loaded_settings.get("negative_prompt", args.negative_prompts[0]),
|
||||
loaded_settings.get("lora_weights", "None"),
|
||||
loaded_settings.get("lora_hf_id", ""),
|
||||
loaded_settings.get("lora_strength", args.lora_strength),
|
||||
loaded_settings.get("scheduler", args.scheduler),
|
||||
loaded_settings.get(
|
||||
"save_metadata_to_png", args.write_metadata_to_png
|
||||
@@ -495,28 +501,25 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
|
||||
)
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
t2i_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
t2i_lora_info = f"LoRA Path: {t2i_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=t2i_lora_info,
|
||||
label=f"LoRA Weights",
|
||||
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
|
||||
elem_id="lora_weights",
|
||||
value=default_settings.get("lora_weights"),
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
scale=3,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value=default_settings.get("lora_hf_id"),
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
lora_strength = gr.Number(
|
||||
label="LoRA Strength",
|
||||
info="Will be baked into the .vmfb",
|
||||
step=0.01,
|
||||
# number is checked on change so to allow 0.n values
|
||||
# we have to allow 0 or you can't type 0.n in
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=default_settings.get("lora_strength"),
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
@@ -736,7 +739,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
|
||||
prompt,
|
||||
negative_prompt,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
scheduler,
|
||||
save_metadata_to_png,
|
||||
save_metadata_to_json,
|
||||
@@ -769,7 +772,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
|
||||
prompt,
|
||||
negative_prompt,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
scheduler,
|
||||
save_metadata_to_png,
|
||||
save_metadata_to_json,
|
||||
@@ -813,7 +816,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
use_hiresfix,
|
||||
@@ -856,7 +859,6 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
outputs=[
|
||||
@@ -871,7 +873,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
custom_vae,
|
||||
],
|
||||
)
|
||||
@@ -902,3 +904,11 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
lora_strength.change(
|
||||
fn=lora_strength_changed,
|
||||
inputs=lora_strength,
|
||||
outputs=lora_strength,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -13,7 +13,10 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_upscaler_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import (
|
||||
lora_changed,
|
||||
lora_strength_changed,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
@@ -53,7 +56,7 @@ def upscaler_inf(
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
@@ -100,9 +103,8 @@ def upscaler_inf(
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
|
||||
args.lora_strength = lora_strength
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
@@ -120,6 +122,7 @@ def upscaler_inf(
|
||||
args.width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
@@ -159,6 +162,7 @@ def upscaler_inf(
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -318,28 +322,25 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
)
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
upscaler_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=upscaler_lora_info,
|
||||
label=f"LoRA Weights",
|
||||
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
scale=3,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
lora_strength = gr.Number(
|
||||
label="LoRA Strength",
|
||||
info="Will be baked into the .vmfb",
|
||||
step=0.01,
|
||||
# number is checked on change so to allow 0.n values
|
||||
# we have to allow 0 or you can't type 0.n in
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=args.lora_strength,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
@@ -523,7 +524,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
@@ -552,3 +553,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
lora_strength.change(
|
||||
fn=lora_strength_changed,
|
||||
inputs=lora_strength,
|
||||
outputs=lora_strength,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -33,6 +33,7 @@ class Config:
|
||||
width: int
|
||||
device: str
|
||||
use_lora: str
|
||||
lora_strength: float
|
||||
stencils: list[str]
|
||||
ondemand: str # should this be expecting a bool instead?
|
||||
|
||||
@@ -180,14 +181,16 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
|
||||
def get_custom_vae_or_lora_weights(weights, hf_id, model):
|
||||
use_weight = ""
|
||||
if weights == "None" and not hf_id:
|
||||
def get_custom_vae_or_lora_weights(weights, model):
|
||||
if weights == "None":
|
||||
use_weight = ""
|
||||
elif not hf_id:
|
||||
use_weight = get_custom_model_pathfile(weights, model)
|
||||
else:
|
||||
use_weight = hf_id
|
||||
custom_weights = get_custom_model_pathfile(str(weights), model)
|
||||
if os.path.isfile(custom_weights):
|
||||
use_weight = custom_weights
|
||||
else:
|
||||
use_weight = weights
|
||||
|
||||
return use_weight
|
||||
|
||||
|
||||
|
||||
@@ -122,20 +122,26 @@ def find_vae_from_png_metadata(
|
||||
|
||||
def find_lora_from_png_metadata(
|
||||
key: str, metadata: dict[str, str | int]
|
||||
) -> tuple[str, str]:
|
||||
lora_hf_id = ""
|
||||
) -> tuple[str, float]:
|
||||
lora_custom = ""
|
||||
lora_strength = 1.0
|
||||
|
||||
if key in metadata:
|
||||
lora_file = metadata[key]
|
||||
split_metadata = metadata[key].split(":")
|
||||
lora_file = split_metadata[0]
|
||||
if len(split_metadata) == 2:
|
||||
try:
|
||||
lora_strength = float(split_metadata[1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
|
||||
# If nothing had matched, check vendor/hf_model_id
|
||||
if not lora_custom and lora_file.count("/"):
|
||||
lora_hf_id = lora_file
|
||||
lora_custom = lora_file
|
||||
|
||||
# LoRA input is optional, should not print or throw an error if missing
|
||||
|
||||
return lora_custom, lora_hf_id
|
||||
return lora_custom, lora_strength
|
||||
|
||||
|
||||
def import_png_metadata(
|
||||
@@ -150,7 +156,6 @@ def import_png_metadata(
|
||||
height,
|
||||
custom_model,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
):
|
||||
try:
|
||||
@@ -160,9 +165,10 @@ def import_png_metadata(
|
||||
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
|
||||
"Model", metadata
|
||||
)
|
||||
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
|
||||
"LoRA", metadata
|
||||
)
|
||||
(
|
||||
custom_lora,
|
||||
custom_lora_strength,
|
||||
) = find_lora_from_png_metadata("LoRA", metadata)
|
||||
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
|
||||
|
||||
negative_prompt = metadata["Negative prompt"]
|
||||
@@ -177,12 +183,8 @@ def import_png_metadata(
|
||||
elif "Model" in metadata and png_hf_model_id:
|
||||
custom_model = png_hf_model_id
|
||||
|
||||
if "LoRA" in metadata and lora_custom_model:
|
||||
custom_lora = lora_custom_model
|
||||
hf_lora_id = ""
|
||||
if "LoRA" in metadata and lora_hf_model_id:
|
||||
if "LoRA" in metadata and not custom_lora:
|
||||
custom_lora = "None"
|
||||
hf_lora_id = lora_hf_model_id
|
||||
|
||||
if "VAE" in metadata and vae_custom_model:
|
||||
custom_vae = vae_custom_model
|
||||
@@ -215,6 +217,6 @@ def import_png_metadata(
|
||||
height,
|
||||
custom_model,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_lora_strength,
|
||||
custom_vae,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user