Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-04-03 03:00:17 -04:00)

Commit: Various bugfixes and SDXL additions.
@@ -436,24 +436,40 @@ class SharkifyStableDiffusionModel:
         super().__init__()
         self.vae = None
         if custom_vae == "":
             print(f"Loading default vae, with target {model_id}")
             self.vae = AutoencoderKL.from_pretrained(
                 model_id,
                 subfolder="vae",
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
         elif not isinstance(custom_vae, dict):
-            self.vae = AutoencoderKL.from_pretrained(
-                custom_vae,
-                subfolder="vae",
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+            print(f"Loading custom vae, with target {custom_vae}")
+            if os.path.exists(custom_vae):
+                self.vae = AutoencoderKL.from_pretrained(
+                    custom_vae,
+                    low_cpu_mem_usage=low_cpu_mem_usage,
+                )
+            else:
+                custom_vae = "/".join(
+                    [
+                        custom_vae.split("/")[-2].split("\\")[-1],
+                        custom_vae.split("/")[-1],
+                    ]
+                )
+                print("Using hub to get custom vae")
+                self.vae = AutoencoderKL.from_pretrained(
+                    custom_vae,
+                    low_cpu_mem_usage=low_cpu_mem_usage,
+                )
         else:
+            print(f"Loading custom vae, with target {custom_vae}")
             self.vae = AutoencoderKL.from_pretrained(
                 model_id,
                 subfolder="vae",
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
             self.vae.load_state_dict(custom_vae)
         self.base_vae = base_vae

     def forward(self, latents):
         image = self.vae.decode(latents / 0.13025, return_dict=False)[
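The custom-VAE branch above first tries `custom_vae` as a local path; when the path does not exist, it rebuilds the last two path segments into a Hugging Face repo id and loads from the hub. A minimal sketch of that fallback under the same string logic (the helper name and example path are illustrative, not part of the commit):

    def resolve_custom_vae_repo_id(custom_vae: str) -> str:
        # Keep the last two segments ("user/repo"), tolerating a Windows-style
        # prefix such as "C:\\weights\\madebyollin/sdxl-vae-fp16-fix".
        return "/".join(
            [
                custom_vae.split("/")[-2].split("\\")[-1],
                custom_vae.split("/")[-1],
            ]
        )

    assert (
        resolve_custom_vae_repo_id("C:\\weights\\madebyollin/sdxl-vae-fp16-fix")
        == "madebyollin/sdxl-vae-fp16-fix"
    )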
@@ -465,7 +481,12 @@ class SharkifyStableDiffusionModel:
         inputs = tuple(self.inputs["vae"])
         # Make sure the VAE is in float32 mode, as it overflows in float16
         # as per the SDXL pipeline.
-        is_f16 = False
+        if not self.custom_vae:
+            is_f16 = False
+        elif "16" in self.custom_vae:
+            is_f16 = True
+        else:
+            is_f16 = False
         save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
         if self.debug:
             os.makedirs(save_dir, exist_ok=True)
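The `is_f16` selection above is purely a name heuristic: the VAE is compiled in float16 only when a custom VAE is set and its identifier contains "16" (e.g. madebyollin/sdxl-vae-fp16-fix); the default VAE and everything else stay in float32 because of the overflow called out in the comment. Condensed to one predicate (a sketch; the function name is illustrative):

    def vae_runs_in_f16(custom_vae) -> bool:
        # Default VAE (empty/None) -> float32; otherwise trust the name.
        return bool(custom_vae) and "16" in custom_vae

    assert vae_runs_in_f16("madebyollin/sdxl-vae-fp16-fix")
    assert not vae_runs_in_f16("")
    assert not vae_runs_in_f16("stabilityai/sd-vae")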
@@ -158,9 +158,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
         use_base_vae,
         cpu_scheduling,
         max_embeddings_multiples,
-        stencils,
         images,
         resample_type,
+        control_mode,
+        stencils,
     ):
         # prompts and negative prompts must be a list.
         if isinstance(prompts, str):
@@ -16,7 +16,10 @@ from diffusers import (
     KDPM2AncestralDiscreteScheduler,
     HeunDiscreteScheduler,
 )
-from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
+from apps.stable_diffusion.src.schedulers import (
+    SharkEulerDiscreteScheduler,
+    SharkEulerAncestralDiscreteScheduler,
+)
 from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
     StableDiffusionPipeline,
 )
@@ -38,6 +41,7 @@ class Text2ImageSDXLPipeline(StableDiffusionPipeline):
     EulerAncestralDiscreteScheduler,
     DPMSolverMultistepScheduler,
     SharkEulerDiscreteScheduler,
+    SharkEulerAncestralDiscreteScheduler,
     DEISMultistepScheduler,
     DDPMScheduler,
     DPMSolverSinglestepScheduler,
@@ -48,8 +52,10 @@ class Text2ImageSDXLPipeline(StableDiffusionPipeline):
         import_mlir: bool,
         use_lora: str,
         ondemand: bool,
+        is_fp32_vae: bool,
     ):
         super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
+        self.is_fp32_vae = is_fp32_vae

     def prepare_latents(
         self,
@@ -206,7 +212,9 @@ class Text2ImageSDXLPipeline(StableDiffusionPipeline):
         # imgs = self.decode_latents_sdxl(None)
         # all_imgs.extend(imgs)
         for i in range(0, latents.shape[0], batch_size):
-            imgs = self.decode_latents_sdxl(latents[i : i + batch_size])
+            imgs = self.decode_latents_sdxl(
+                latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae
+            )
             all_imgs.extend(imgs)
         if self.ondemand:
             self.unload_vae()
@@ -62,6 +62,7 @@ class StableDiffusionPipeline:
         import_mlir: bool,
         use_lora: str,
         ondemand: bool,
+        is_f32_vae: bool = False,
     ):
         self.vae = None
         self.text_encoder = None
@@ -77,6 +78,7 @@ class StableDiffusionPipeline:
         self.import_mlir = import_mlir
         self.use_lora = use_lora
         self.ondemand = ondemand
+        self.is_f32_vae = is_f32_vae
         # TODO: Find a better workaround for fetching base_model_id early
         # enough for CLIPTokenizer.
         try:
@@ -332,7 +334,7 @@ class StableDiffusionPipeline:
         gc.collect()

         # TODO: Look into dtype for text_encoder_2!
-        prompt_embeds = prompt_embeds.to(dtype=torch.float32)
+        prompt_embeds = prompt_embeds.to(dtype=torch.float16)
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -569,11 +571,15 @@ class StableDiffusionPipeline:

         return latents

-    def decode_latents_sdxl(self, latents):
-        latents = latents.to(torch.float32)
+    def decode_latents_sdxl(self, latents, is_fp32_vae):
+        # latents are in the unet dtype here, so switch if we want to use fp32
+        if is_fp32_vae:
+            print("Casting latents to float32 for VAE")
+            latents = latents.to(torch.float32)
         images = self.vae("forward", (latents,))
         images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
         images = images.cpu().permute(0, 2, 3, 1).float().numpy()

         images = (images * 255).round().astype("uint8")
         pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
@@ -666,6 +672,17 @@ class StableDiffusionPipeline:
             return cls(
                 scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
             )
+        if cls.__name__ == "Text2ImageSDXLPipeline":
+            is_fp32_vae = "16" not in custom_vae
+            return cls(
+                scheduler,
+                sd_model,
+                import_mlir,
+                use_lora,
+                ondemand,
+                is_fp32_vae,
+            )

         return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)

 # #####################################################
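Taken together, these hunks thread one flag from from_pretrained down to decode time: whether to run the VAE in float32 is inferred from the custom VAE's name, stored on the pipeline, and applied as a cast just before decoding. A condensed sketch of the decode-side behavior (a standalone illustration; `vae` stands in for the compiled SharkInference module used above):

    import torch

    def decode(vae, latents: torch.Tensor, is_fp32_vae: bool):
        # The UNet may have produced fp16 latents; cast only when the VAE
        # was compiled for float32 (i.e. its name does not advertise fp16).
        if is_fp32_vae:
            latents = latents.to(torch.float32)
        return vae("forward", (latents,))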
@@ -1,4 +1,5 @@
 from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
 from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
     SharkEulerDiscreteScheduler,
+    SharkEulerAncestralDiscreteScheduler,
 )
@@ -14,6 +14,7 @@ from diffusers import (
 )
 from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
     SharkEulerDiscreteScheduler,
+    SharkEulerAncestralDiscreteScheduler,
 )


@@ -84,6 +85,12 @@ def get_schedulers(model_id):
         model_id,
         subfolder="scheduler",
     )
+    schedulers[
+        "SharkEulerAncestralDiscrete"
+    ] = SharkEulerAncestralDiscreteScheduler.from_pretrained(
+        model_id,
+        subfolder="scheduler",
+    )
     schedulers[
         "DPMSolverSinglestep"
     ] = DPMSolverSinglestepScheduler.from_pretrained(
@@ -101,4 +108,5 @@ def get_schedulers(model_id):
         subfolder="scheduler",
     )
     schedulers["SharkEulerDiscrete"].compile()
+    schedulers["SharkEulerAncestralDiscrete"].compile()
     return schedulers
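With those registrations in place, the new scheduler is obtained exactly like the existing Shark scheduler; roughly (a usage sketch, assuming the code above):

    schedulers = get_schedulers("stabilityai/stable-diffusion-xl-base-1.0")
    # get_schedulers() compiles both Shark schedulers before returning, so the
    # selected entry is ready to run through the IREE backend.
    sampler = schedulers["SharkEulerAncestralDiscrete"]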
@@ -7,6 +7,7 @@ from diffusers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,
     EulerDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
 )
 from diffusers.configuration_utils import register_to_config
 from apps.stable_diffusion.src.utils import (
@@ -27,6 +28,10 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
+        interpolation_type: str = "linear",
+        use_karras_sigmas: bool = False,
+        timestep_spacing: str = "linspace",
+        steps_offset: int = 0,
     ):
         super().__init__(
             num_train_timesteps,
@@ -35,6 +40,10 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
             beta_schedule,
             trained_betas,
             prediction_type,
+            interpolation_type,
+            use_karras_sigmas,
+            timestep_spacing,
+            steps_offset,
         )

     def compile(self):
@@ -152,3 +161,144 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
             ),
             send_to_host=False,
         )
+
+
+class SharkEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
+        timestep_spacing: str = "linspace",
+        steps_offset: int = 0,
+    ):
+        super().__init__(
+            num_train_timesteps,
+            beta_start,
+            beta_end,
+            beta_schedule,
+            trained_betas,
+            prediction_type,
+            timestep_spacing,
+            steps_offset,
+        )
+
+    def compile(self):
+        SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
+        BATCH_SIZE = args.batch_size
+        device = args.device.split(":", 1)[0].strip()
+
+        model_input = {
+            "euler": {
+                "latent": torch.randn(
+                    BATCH_SIZE, 4, args.height // 8, args.width // 8
+                ),
+                "output": torch.randn(
+                    BATCH_SIZE, 4, args.height // 8, args.width // 8
+                ),
+                "sigma": torch.tensor(1).to(torch.float32),
+                "dt": torch.tensor(1).to(torch.float32),
+            },
+        }
+
+        example_latent = model_input["euler"]["latent"]
+        example_output = model_input["euler"]["output"]
+        if args.precision == "fp16":
+            example_latent = example_latent.half()
+            example_output = example_output.half()
+        example_sigma = model_input["euler"]["sigma"]
+        example_dt = model_input["euler"]["dt"]
+
+        class ScalingModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, latent, sigma):
+                return latent / ((sigma**2 + 1) ** 0.5)
+
+        class SchedulerStepModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, noise_pred, sigma, latent, dt):
+                pred_original_sample = latent - sigma * noise_pred
+                derivative = (latent - pred_original_sample) / sigma
+                return latent + derivative * dt
+
+        iree_flags = []
+        if len(args.iree_vulkan_target_triple) > 0:
+            iree_flags.append(
+                f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+            )
+
+        def _import(self):
+            scaling_model = ScalingModel()
+            self.scaling_model, _ = compile_through_fx(
+                model=scaling_model,
+                inputs=(example_latent, example_sigma),
+                extended_model_name=f"euler_ancestral_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
+                + args.precision,
+                extra_args=iree_flags,
+            )
+
+            step_model = SchedulerStepModel()
+            self.step_model, _ = compile_through_fx(
+                step_model,
+                (example_output, example_sigma, example_latent, example_dt),
+                extended_model_name=f"euler_ancestral_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
+                + args.precision,
+                extra_args=iree_flags,
+            )
+
+        if args.import_mlir:
+            _import(self)
+        else:
+            try:
+                self.scaling_model = get_shark_model(
+                    SCHEDULER_BUCKET,
+                    "euler_ancestral_scale_model_input_" + args.precision,
+                    iree_flags,
+                )
+                self.step_model = get_shark_model(
+                    SCHEDULER_BUCKET,
+                    "euler_ancestral_step_" + args.precision,
+                    iree_flags,
+                )
+            except:
+                print(
+                    "failed to download model, falling back and using import_mlir"
+                )
+                args.import_mlir = True
+                _import(self)
+
+    def scale_model_input(self, sample, timestep):
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
+        return self.scaling_model(
+            "forward",
+            (
+                sample,
+                sigma,
+            ),
+            send_to_host=False,
+        )
+
+    def step(self, noise_pred, timestep, latent):
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
+        dt = self.sigmas[step_index + 1] - sigma
+        return self.step_model(
+            "forward",
+            (
+                noise_pred,
+                sigma,
+                latent,
+                dt,
+            ),
+            send_to_host=False,
+        )
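It is worth spelling out what the compiled SchedulerStepModel actually computes: with an epsilon prediction, pred_original_sample = latent - sigma * noise_pred, so (latent - pred_original_sample) / sigma collapses back to noise_pred, and the update is latent + noise_pred * dt, a deterministic Euler step over the sigma schedule (any ancestral noise injection would have to happen outside this compiled graph). A quick check of that algebra (shapes and values are illustrative):

    import torch

    latent = torch.randn(1, 4, 128, 128)
    noise_pred = torch.randn_like(latent)
    sigma, dt = torch.tensor(14.6), torch.tensor(-2.5)

    pred_original_sample = latent - sigma * noise_pred
    derivative = (latent - pred_original_sample) / sigma  # == noise_pred
    stepped = latent + derivative * dt

    assert torch.allclose(derivative, noise_pred, atol=1e-5)
    assert torch.allclose(stepped, latent + noise_pred * dt, atol=1e-5)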
@@ -11,8 +11,9 @@ from apps.stable_diffusion.web.ui.utils import (
     get_custom_model_path,
     get_custom_model_files,
     scheduler_list,
-    predefined_models,
+    predefined_sdxl_models,
     cancel_sd,
+    set_model_default_configs,
 )
 from apps.stable_diffusion.web.utils.metadata import import_png_metadata
 from apps.stable_diffusion.web.utils.common_label_calc import status_label
@@ -50,17 +51,17 @@ def txt2img_sdxl_inf(
     batch_size: int,
     scheduler: str,
     model_id: str,
+    custom_vae: str,
     precision: str,
     device: str,
     max_length: int,
     save_metadata_to_json: bool,
     save_metadata_to_png: bool,
     lora_weights: str,
     lora_hf_id: str,
     ondemand: bool,
     repeatable_seeds: bool,
 ):
-    if precision != "fp16":
-        print("currently we support fp16 for SDXL")
-        precision = "fp16"
     from apps.stable_diffusion.web.ui.utils import (
         get_custom_model_pathfile,
         get_custom_vae_or_lora_weights,
@@ -71,6 +72,10 @@ def txt2img_sdxl_inf(
         SD_STATE_CANCEL,
     )

+    if precision != "fp16":
+        print("currently we support fp16 for SDXL")
+        precision = "fp16"
+
     args.prompts = [prompt]
     args.negative_prompts = [negative_prompt]
     args.guidance_scale = guidance_scale
@@ -93,13 +98,15 @@ def txt2img_sdxl_inf(
     else:
         args.hf_model_id = model_id

-    # if custom_vae != "None":
-    #     args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
+    if custom_vae:
+        args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")

     args.save_metadata_to_json = save_metadata_to_json
     args.write_metadata_to_png = save_metadata_to_png

-    args.use_lora = ""
+    args.use_lora = get_custom_vae_or_lora_weights(
+        lora_weights, lora_hf_id, "lora"
+    )

     dtype = torch.float32 if precision == "fp32" else torch.half
     cpu_scheduling = not scheduler.startswith("Shark")
@@ -144,31 +151,29 @@ def txt2img_sdxl_inf(
     )
     global_obj.set_schedulers(get_schedulers(model_id))
     scheduler_obj = global_obj.get_scheduler(scheduler)
+    # For SDXL we set max_length to 77.
+    print("Setting max_length = 77")
+    max_length = 77
     if global_obj.get_cfg_obj().ondemand:
         print("Running txt2img in memory efficient mode.")
-    txt2img_sdxl_obj = Text2ImageSDXLPipeline.from_pretrained(
-        scheduler=scheduler_obj,
-        import_mlir=args.import_mlir,
-        model_id=args.hf_model_id,
-        ckpt_loc=args.ckpt_loc,
-        precision=precision,
-        max_length=max_length,
-        batch_size=batch_size,
-        height=height,
-        width=width,
-        use_base_vae=args.use_base_vae,
-        use_tuned=args.use_tuned,
-        custom_vae=args.custom_vae,
-        low_cpu_mem_usage=args.low_cpu_mem_usage,
-        debug=args.import_debug if args.import_mlir else False,
-        use_lora=args.use_lora,
-        use_quantize=args.use_quantize,
-        ondemand=global_obj.get_cfg_obj().ondemand,
-    )
-    global_obj.set_sd_obj(txt2img_sdxl_obj)
+    global_obj.set_sd_obj(
+        Text2ImageSDXLPipeline.from_pretrained(
+            scheduler=scheduler_obj,
+            import_mlir=args.import_mlir,
+            model_id=args.hf_model_id,
+            ckpt_loc=args.ckpt_loc,
+            precision=precision,
+            max_length=max_length,
+            batch_size=batch_size,
+            height=height,
+            width=width,
+            use_base_vae=args.use_base_vae,
+            use_tuned=args.use_tuned,
+            custom_vae=args.custom_vae,
+            low_cpu_mem_usage=args.low_cpu_mem_usage,
+            debug=args.import_debug if args.import_mlir else False,
+            use_lora=args.use_lora,
+            use_quantize=args.use_quantize,
+            ondemand=global_obj.get_cfg_obj().ondemand,
+        )
+    )

     global_obj.set_sd_scheduler(scheduler)
@@ -239,7 +244,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
     with gr.Row():
         with gr.Column(scale=10):
             with gr.Row():
-                t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
+                t2i_sdxl_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
                 txt2img_sdxl_custom_model = gr.Dropdown(
                     label=f"Models",
                     info="Select, or enter HuggingFace Model ID or Civitai model download URL",
@@ -247,12 +252,39 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
                     value=os.path.basename(args.ckpt_loc)
                     if args.ckpt_loc
                     else "stabilityai/stable-diffusion-xl-base-1.0",
-                    choices=[
-                        "stabilityai/stable-diffusion-xl-base-1.0"
-                    ],
+                    choices=predefined_sdxl_models
+                    + get_custom_model_files(
+                        custom_checkpoint_type="sdxl"
+                    ),
                     allow_custom_value=True,
                     scale=2,
                 )
+                t2i_sdxl_vae_info = (
+                    str(get_custom_model_path("vae"))
+                ).replace("\\", "\n\\")
+                t2i_sdxl_vae_info = (
+                    f"VAE Path: {t2i_sdxl_vae_info}"
+                )
+                custom_vae = gr.Dropdown(
+                    label=f"VAE Models",
+                    info=t2i_sdxl_vae_info,
+                    elem_id="custom_model",
+                    value="None",
+                    choices=[
+                        "None",
+                        "madebyollin/sdxl-vae-fp16-fix",
+                    ]
+                    + get_custom_model_files("vae"),
+                    allow_custom_value=True,
+                    scale=1,
+                )
             with gr.Column(scale=1, min_width=170):
                 txt2img_sdxl_png_info_img = gr.Image(
                     label="Import PNG info",
                     elem_id="txt2img_prompt_image",
                     type="pil",
                     visible=True,
                 )

         with gr.Group(elem_id="prompt_box_outer"):
             prompt = gr.Textbox(
@@ -267,16 +299,49 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
                 lines=2,
                 elem_id="negative_prompt_box",
             )

+        with gr.Accordion(label="LoRA Options", open=False):
+            with gr.Row():
+                # janky fix for overflowing text
+                t2i_sdxl_lora_info = (
+                    str(get_custom_model_path("lora"))
+                ).replace("\\", "\n\\")
+                t2i_sdxl_lora_info = f"LoRA Path: {t2i_sdxl_lora_info}"
+                lora_weights = gr.Dropdown(
+                    label=f"Standalone LoRA Weights",
+                    info=t2i_sdxl_lora_info,
+                    elem_id="lora_weights",
+                    value="None",
+                    choices=["None"] + get_custom_model_files("lora"),
+                    allow_custom_value=True,
+                )
+                lora_hf_id = gr.Textbox(
+                    elem_id="lora_hf_id",
+                    placeholder="Select 'None' in the Standalone LoRA "
+                    "weights dropdown on the left if you want to use "
+                    "a standalone HuggingFace model ID for LoRA here, "
+                    "e.g. sayakpaul/sd-model-finetuned-lora-t4",
+                    value="",
+                    label="HuggingFace Model ID",
+                    lines=3,
+                )
+            with gr.Row():
+                lora_tags = gr.HTML(
+                    value="<div><i>No LoRA selected</i></div>",
+                    elem_classes="lora-tags",
+                )
         with gr.Accordion(label="Advanced Options", open=False):
             with gr.Row():
                 scheduler = gr.Dropdown(
                     elem_id="scheduler",
                     label="Scheduler",
-                    value="DDIM",
-                    choices=["DDIM"],
-                    allow_custom_value=True,
-                    visible=False,
+                    value=args.scheduler,
+                    choices=[
+                        "DDIM",
+                        "SharkEulerAncestralDiscrete",
+                        "SharkEulerDiscrete",
+                    ],
+                    allow_custom_value=False,
+                    visible=True,
                 )
         with gr.Column():
             save_metadata_to_png = gr.Checkbox(
@@ -291,18 +356,22 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
             )
             with gr.Row():
                 height = gr.Slider(
                     512,
                     1024,
                     value=1024,
-                    step=8,
+                    step=512,
                     label="Height",
-                    visible=False,
+                    visible=True,
                     interactive=True,
                 )
                 width = gr.Slider(
                     512,
                     1024,
                     value=1024,
-                    step=8,
+                    step=512,
                     label="Width",
-                    visible=False,
+                    visible=True,
                     interactive=True,
                 )
                 precision = gr.Radio(
                     label="Precision",
@@ -315,7 +384,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
             )
             max_length = gr.Radio(
                 label="Max Length",
-                value=args.max_length,
+                value=77,
                 choices=[
                     64,
                     77,
@@ -394,7 +463,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
                 object_fit="contain",
             )
             std_output = gr.Textbox(
-                value=f"{t2i_model_info}\n"
+                value=f"{t2i_sdxl_model_info}\n"
                 f"Images will be saved at "
                 f"{get_generated_imgs_path()}",
                 lines=1,
@@ -429,11 +498,14 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
                 batch_size,
                 scheduler,
                 txt2img_sdxl_custom_model,
+                custom_vae,
                 precision,
                 device,
                 max_length,
                 save_metadata_to_json,
                 save_metadata_to_png,
                 lora_weights,
                 lora_hf_id,
                 ondemand,
                 repeatable_seeds,
             ],
@@ -456,3 +528,51 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
         fn=cancel_sd,
         cancels=[prompt_submit, neg_prompt_submit, generate_click],
     )
+
+    txt2img_sdxl_png_info_img.change(
+        fn=import_png_metadata,
+        inputs=[
+            txt2img_sdxl_png_info_img,
+            prompt,
+            negative_prompt,
+            steps,
+            scheduler,
+            guidance_scale,
+            seed,
+            width,
+            height,
+            txt2img_sdxl_custom_model,
+            lora_weights,
+            lora_hf_id,
+            custom_vae,
+        ],
+        outputs=[
+            txt2img_sdxl_png_info_img,
+            prompt,
+            negative_prompt,
+            steps,
+            scheduler,
+            guidance_scale,
+            seed,
+            width,
+            height,
+            txt2img_sdxl_custom_model,
+            lora_weights,
+            lora_hf_id,
+            custom_vae,
+        ],
+    )
+    txt2img_sdxl_custom_model.select(
+        fn=set_model_default_configs,
+        inputs=[
+            txt2img_sdxl_custom_model,
+        ],
+        outputs=[
+            steps,
+            scheduler,
+            guidance_scale,
+            width,
+            height,
+            custom_vae,
+        ],
+    )
@@ -64,9 +64,11 @@ scheduler_list_cpu_only = [
     "DPMSolverSinglestep",
     "DDPM",
     "HeunDiscrete",
+    "LCMScheduler",
 ]
 scheduler_list = scheduler_list_cpu_only + [
     "SharkEulerDiscrete",
+    "SharkEulerAncestralDiscrete",
 ]

 predefined_models = [
@@ -87,6 +89,10 @@ predefined_paint_models = [
 predefined_upscaler_models = [
     "stabilityai/stable-diffusion-x4-upscaler",
 ]
+predefined_sdxl_models = [
+    "stabilityai/sdxl-turbo",
+    "stabilityai/stable-diffusion-xl-base-1.0",
+]


 def resource_path(relative_path):
@@ -140,6 +146,12 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
         )
     ]
     match custom_checkpoint_type:
+        case "sdxl":
+            files = [
+                val
+                for val in files
+                if any(x in val for x in ["XL", "xl", "Xl"])
+            ]
         case "inpainting":
             files = [
                 val
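The new "sdxl" case is a plain substring filter over the checkpoint filenames gathered above: anything containing "XL", "xl", or "Xl" is treated as an SDXL checkpoint (a name like "sdxL" would slip through, since that casing is not in the list). For example (illustrative filenames):

    files = [
        "sd_xl_base_1.0.safetensors",
        "juggernautXL_v7.safetensors",
        "v1-5-pruned.ckpt",
    ]
    sdxl = [f for f in files if any(x in f for x in ["XL", "xl", "Xl"])]
    # -> ["sd_xl_base_1.0.safetensors", "juggernautXL_v7.safetensors"]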
@@ -247,6 +259,63 @@ def cancel_sd():
     pass


+def set_model_default_configs(model_ckpt_or_id, jsonconfig=None):
+    import gradio as gr
+
+    if jsonconfig:
+        return get_config_from_json(model_ckpt_or_id, jsonconfig)
+    elif default_config_exists(model_ckpt_or_id):
+        return default_configs[model_ckpt_or_id]
+    # TODO: Use HF metadata to set up the pipeline if available
+    # elif is_valid_hf_id(model_ckpt_or_id):
+    #     return get_HF_default_configs(model_ckpt_or_id)
+    else:
+        # We don't have default metadata to set up a good config. Do not change the configs.
+        return [
+            gr.update(),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+        ]
+
+
+def get_config_from_json(model_ckpt_or_id, jsonconfig):
+    # TODO: make this work properly. It is currently not user-exposed.
+    cfgdata = json.load(jsonconfig)
+    return [
+        cfgdata["steps"],
+        cfgdata["scheduler"],
+        cfgdata["guidance_scale"],
+        cfgdata["width"],
+        cfgdata["height"],
+        cfgdata["custom_vae"],
+    ]
+
+
+def default_config_exists(model_ckpt_or_id):
+    if model_ckpt_or_id in [
+        "stabilityai/sdxl-turbo",
+        "stabilityai/stable-diffusion-xl-base-1.0",
+    ]:
+        return True
+    else:
+        return False
+
+
+default_configs = {
+    "stabilityai/sdxl-turbo": [1, "DDIM", 0, 512, 512, ""],
+    "stabilityai/stable-diffusion-xl-base-1.0": [
+        50,
+        "DDIM",
+        7.5,
+        512,
+        512,
+        "madebyollin/sdxl-vae-fp16-fix",
+    ],
+}
+
 nodlogo_loc = resource_path("logos/nod-logo.png")
 nodicon_loc = resource_path("logos/nod-icon.png")
 available_devices = get_available_devices()
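Each six-element list returned by set_model_default_configs lines up positionally with the outputs wired to txt2img_sdxl_custom_model.select above: steps, scheduler, guidance_scale, width, height, custom_vae. Returning bare gr.update() in every slot is Gradio's no-op, leaving the corresponding component untouched when no default config is known. For example (values taken from default_configs):

    steps, scheduler, guidance_scale, width, height, custom_vae = (
        set_model_default_configs("stabilityai/sdxl-turbo")
    )
    # -> 1, "DDIM", 0, 512, 512, ""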