[SDXL] Add SDXL pipeline to SHARK (#1941)

* [SDXL] Add SDXL pipeline to SHARK

-- This commit adds the SDXL pipeline to SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* (SDXL) Fix --ondemand and VAE scale factor use, and fix VAE flags.

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
This commit is contained in:
Abhishek Varma
2023-12-02 14:45:15 +05:30
committed by GitHub
parent ff15fd74f6
commit a1b7110550
13 changed files with 1451 additions and 12 deletions

View File

@@ -0,0 +1,96 @@
import torch
import time
from apps.stable_diffusion.src import (
args,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
def main():
if args.clear_all:
clear_all()
# TODO: prompt_embeds and text_embeds from base_model.json require fixing
args.precision = "fp16"
args.height = 1024
args.width = 1024
args.max_length = 77
args.scheduler = "DDIM"
print(
"Using the default supported configuration for SDXL:\n"
"precision=fp16, width*height=1024*1024, max_length=77 and scheduler=DDIM"
)
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += (
f" seed={seeds[current_batch]}, size={args.height}x{args.width}"
)
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: with --batch_count=x, txt2img_obj.log re-displays info from all previous iterations on every batch
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seeds[current_batch])
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -9,6 +9,7 @@ from apps.stable_diffusion.src.utils import (
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
Text2ImageSDXLPipeline,
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,

View File

@@ -1,5 +1,5 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection
from collections import defaultdict
from pathlib import Path
import torch
@@ -24,6 +24,8 @@ from apps.stable_diffusion.src.utils import (
get_stencil_model_id,
update_lora_weight,
)
from shark.shark_downloader import download_public_file
from shark.shark_inference import SharkInference
# These shapes are parameter dependent.
@@ -55,6 +57,10 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape.append(math.ceil(height / div_val))
elif "width" in shape[i]:
new_shape.append(math.ceil(width / div_val))
elif "+" in shape[i]:
# Currently this case is only hit for SDXL. If any other case ever
# needs this operator, update this logic accordingly.
new_shape.append(height + width)
else:
new_shape.append(shape[i])
return new_shape
@@ -67,6 +73,70 @@ def check_compilation(model, model_name):
)
def shark_compile_after_ir(
module_name,
device,
vmfb_path,
precision,
ir_path=None,
):
if ir_path:
print(f"[DEBUG] mlir found at {ir_path.absolute()}")
module = SharkInference(
mlir_module=ir_path,
device=device,
mlir_dialect="tm_tensor",
)
print(f"Fetching extra compilation flags for {module_name} (precision = {precision})")
path = module.save_module(
vmfb_path.parent.absolute(),
vmfb_path.stem,
extra_args=get_opt_flags(module_name, precision=precision),
)
print(f"Saved {module_name} vmfb at {path}")
module.load_module(path)
return module
def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision):
name_split = extended_model_name.split("_")
if "vae" in model_name:
name_split[5] = "fp32"
extended_model_name_for_vmfb = "_".join(name_split)
extended_model_name_for_mlir = "_".join(name_split[:-1])
vmfb_path = Path(extended_model_name_for_vmfb + ".vmfb")
if "vulkan" in device:
_device = args.iree_vulkan_target_triple
_device = _device.replace("-", "_")
vmfb_path = Path(extended_model_name_for_vmfb + f"_{_device}.vmfb")
if vmfb_path.exists():
shark_module = SharkInference(
None,
device=device,
mlir_dialect="tm_tensor",
)
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=[])
return shark_module, None
mlir_path = Path(extended_model_name_for_mlir + ".mlir")
if not mlir_path.exists():
print(f"Looking into gs://shark_tank/SDXL/mlir/{mlir_path.name}")
download_public_file(
f"gs://shark_tank/SDXL/mlir/{mlir_path.name}",
mlir_path.absolute(),
single_file=True,
)
if mlir_path.exists():
return (
shark_compile_after_ir(
model_name, device, vmfb_path, precision, mlir_path
),
None,
)
return None, None
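The helper above implements the SDXL artifact lookup order: an existing local .vmfb is loaded directly; otherwise a local .mlir (downloaded from gs://shark_tank/SDXL/mlir/ if missing) is compiled and loaded; if neither exists it returns (None, None). A rough caller-side sketch of how the getters below use it, with a made-up extended model name purely for illustration:
# Hypothetical caller-side sketch; mirrors how get_unet_sdxl() and get_vae_sdxl()
# below use process_vmfb_ir_sdxl. The extended model name is illustrative only.
shark_module, mlir = process_vmfb_ir_sdxl(
    "example_sdxl_base_77_1024_1024_fp16_unet",  # illustrative, not a real artifact name
    "unet",
    args.device,
    "fp16",
)
if shark_module is None:
    # Nothing cached locally or in gs://shark_tank/SDXL/mlir/, so the caller
    # falls back to tracing the PyTorch model with compile_through_fx.
    pass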
class SharkifyStableDiffusionModel:
def __init__(
self,
@@ -86,6 +156,7 @@ class SharkifyStableDiffusionModel:
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
is_sdxl: bool = False,
stencils: list[str] = [],
use_lora: str = "",
use_quantize: str = None,
@@ -93,6 +164,7 @@ class SharkifyStableDiffusionModel:
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.is_sdxl = is_sdxl
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
@@ -175,6 +247,7 @@ class SharkifyStableDiffusionModel:
model_name = {}
sub_model_list = [
"clip",
"clip2",
"unet",
"unet512",
"stencil_unet",
@@ -343,6 +416,76 @@ class SharkifyStableDiffusionModel:
)
return shark_vae, vae_mlir
def get_vae_sdxl(self):
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
self.model_name["vae"], "vae", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class VaeModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
base_vae=self.base_vae,
custom_vae=self.custom_vae,
low_cpu_mem_usage=False,
):
super().__init__()
self.vae = None
if custom_vae == "":
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
def forward(self, latents):
image = self.vae.decode(latents / 0.13025, return_dict=False)[
0
]
return image
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
# Keep the VAE in float32, since it overflows in float16 (as noted in the
# upstream SDXL pipeline).
is_f16 = False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae, vae_mlir = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae, vae_mlir
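The hard-coded 0.13025 divisor in VaeModel.forward matches the scaling_factor in the SDXL VAE config. A small eager-mode sketch of the same decode step, assuming the stabilityai/stable-diffusion-xl-base-1.0 repo and a 1024x1024 generation (128x128 latents), with the factor read from the config instead of hard-coded:
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
)
latents = torch.randn(1, 4, 128, 128)  # 1024 // 8 = 128
# Equivalent to the `latents / 0.13025` in VaeModel.forward above.
image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
print(image.shape)  # torch.Size([1, 3, 1024, 1024])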
def get_controlled_unet(self, use_large=False):
class ControlledUnetModel(torch.nn.Module):
def __init__(
@@ -758,6 +901,93 @@ class SharkifyStableDiffusionModel:
)
return shark_unet, unet_mlir
def get_unet_sdxl(self):
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
self.model_name["unet"], "unet", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class UnetModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
):
if args.attention_slicing.isdigit():
self.unet.set_attention_slice(
int(args.attention_slicing)
)
else:
self.unet.set_attention_slice(args.attention_slicing)
def forward(
self,
latent,
timestep,
prompt_embeds,
text_embeds,
time_ids,
guidance_scale,
):
added_cond_kwargs = {
"text_embeds": text_embeds,
"time_ids": time_ids,
}
noise_pred = self.unet.forward(
latent,
timestep,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=None,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = self.precision == "fp16"
inputs = tuple(self.inputs["unet"])
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
input_mask = [True, True, True, True, True, True]
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
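As in SHARK's other UNet wrappers, classifier-free guidance is folded into the compiled forward: the module receives the doubled (unconditional, text-conditioned) batch, chunks the prediction, and recombines it with guidance_scale on device, so the host loop only passes the scale as a tensor input. A standalone sketch of that recombination on dummy tensors:
import torch

# Dummy prediction for a CFG batch of two (unconditional first, text-conditioned second).
noise_pred = torch.randn(2, 4, 128, 128)
guidance_scale = torch.tensor(7.5)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 4, 128, 128]); this is what the scheduler steps on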
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(
@@ -805,6 +1035,78 @@ class SharkifyStableDiffusionModel:
)
return shark_clip, clip_mlir
def get_clip_sdxl(self, clip_index=1):
if clip_index == 1:
extended_model_name = self.model_name["clip"]
model_name = "clip"
else:
extended_model_name = self.model_name["clip2"]
model_name = "clip2"
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
extended_model_name, "clip", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class CLIPText(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
clip_index=1,
):
super().__init__()
if clip_index == 1:
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.text_encoder = (
CLIPTextModelWithProjection.from_pretrained(
model_id,
subfolder="text_encoder_2",
low_cpu_mem_usage=low_cpu_mem_usage,
)
)
def forward(self, input):
prompt_embeds = self.text_encoder(
input,
output_hidden_states=True,
)
# We are always only interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
return prompt_embeds, pooled_prompt_embeds
clip_model = CLIPText(
low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index
)
save_dir = os.path.join(self.sharktank_dir, extended_model_name)
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip, clip_mlir = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
extended_model_name=extended_model_name,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
model_name="clip",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_clip, clip_mlir
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
if not custom_vae.endswith((".ckpt", ".safetensors")):
@@ -837,7 +1139,9 @@ class SharkifyStableDiffusionModel:
}
return vae_dict
def compile_unet_variants(self, model, use_large=False):
def compile_unet_variants(self, model, use_large=False, base_model=""):
if self.is_sdxl:
return self.get_unet_sdxl()
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler(use_large=use_large)
@@ -879,6 +1183,22 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def sdxl_clip(self):
try:
self.inputs["clip"] = self.get_input_info_for(
base_models["sdxl_clip"]
)
compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1)
compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2)
check_compilation(compiled_clip, "Clip")
check_compilation(compiled_clip2, "Clip2")
if self.return_mlir:
return clip_mlir, clip_mlir2
return compiled_clip, compiled_clip2
except Exception as e:
sys.exit(e)
def unet(self, use_large=False):
try:
stencil_count = 0
@@ -893,7 +1213,7 @@ class SharkifyStableDiffusionModel:
unet_inputs[self.base_model_id]
)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
model, use_large=use_large, base_model=self.base_model_id
)
else:
for model_id in unet_inputs:
@@ -904,7 +1224,7 @@ class SharkifyStableDiffusionModel:
try:
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
model, use_large=use_large, base_model=model_id
)
except Exception as e:
print(e)
@@ -943,7 +1263,10 @@ class SharkifyStableDiffusionModel:
is_base_vae = self.base_vae
if self.is_upscaler:
self.base_vae = True
compiled_vae, vae_mlir = self.get_vae()
if self.is_sdxl:
compiled_vae, vae_mlir = self.get_vae_sdxl()
else:
compiled_vae, vae_mlir = self.get_vae()
self.base_vae = is_base_vae
check_compilation(compiled_vae, "Vae")

View File

@@ -123,8 +123,8 @@ def get_clip():
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer():
def get_tokenizer(subfolder="tokenizer"):
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder="tokenizer"
args.hf_model_id, subfolder=subfolder
)
return tokenizer

View File

@@ -1,6 +1,9 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
Text2ImageSDXLPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
Image2ImagePipeline,
)

View File

@@ -0,0 +1,214 @@
import torch
import numpy as np
from random import randint
from typing import Union
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Text2ImageSDXLPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def _get_add_time_ids(
self, original_size, crops_coords_top_left, target_size, dtype
):
add_time_ids = list(
original_size + crops_coords_top_left + target_size
)
# self.unet.config.addition_time_embed_dim IS 256.
# self.text_encoder_2.config.projection_dim IS 1280.
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
expected_add_embed_dim = 2816
# self.unet.add_embedding.linear_1.in_features IS 2816.
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
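The dimension check above works out as follows: original_size, crops_coords_top_left and target_size contribute two values each, so add_time_ids has six entries, and with addition_time_embed_dim = 256 and a text_encoder_2 projection_dim of 1280 (the values noted in the comments), 256 * 6 + 1280 = 2816, matching unet.add_embedding.linear_1.in_features. A quick numeric illustration for a 1024x1024 generation:
import torch

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)

add_time_ids = list(original_size + crops_coords_top_left + target_size)
assert len(add_time_ids) == 6
assert 256 * len(add_time_ids) + 1280 == 2816

add_time_ids = torch.tensor([add_time_ids], dtype=torch.float32)
print(add_time_ids.shape)  # torch.Size([1, 6]); doubled to [2, 6] for CFG in generate_images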
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be lists.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# Seed the generator to create the initial latent noise. Also handle out-of-range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings.
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt_sdxl(
prompt=prompts,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt=neg_prompts,
)
# Prepare timesteps.
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# Prepare added time ids & embeddings.
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
add_text_embeds = pooled_prompt_embeds
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
)
prompt_embeds = torch.cat(
[negative_prompt_embeds, prompt_embeds], dim=0
)
add_text_embeds = torch.cat(
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
add_text_embeds = add_text_embeds.to(dtype)
add_time_ids = add_time_ids.repeat(batch_size * 1, 1)
# guidance scale as a tensor in the working dtype.
guidance_scale = torch.tensor(guidance_scale).to(dtype)
prompt_embeds = prompt_embeds.to(dtype)
add_time_ids = add_time_ids.to(dtype)
# Get Image latents.
latents = self.produce_img_latents_sdxl(
init_latents,
timesteps,
add_text_embeds,
add_time_ids,
prompt_embeds,
cpu_scheduling,
guidance_scale,
dtype,
)
# Img latents -> PIL images.
all_imgs = []
self.load_vae()
# imgs = self.decode_latents_sdxl(None)
# all_imgs.extend(imgs)
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents_sdxl(latents[i : i + batch_size])
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -33,6 +33,8 @@ from apps.stable_diffusion.src.utils import (
end_profiling,
)
import sys
import gc
from typing import List, Optional
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
@@ -63,6 +65,7 @@ class StableDiffusionPipeline:
):
self.vae = None
self.text_encoder = None
self.text_encoder_2 = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
@@ -106,6 +109,34 @@ class StableDiffusionPipeline:
del self.text_encoder
self.text_encoder = None
def load_clip_sdxl(self):
if self.text_encoder and self.text_encoder_2:
return
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. "
"Importing MLIR anyway."
)
self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip()
else:
try:
# TODO: Fix this for SDXL
self.text_encoder = get_clip()
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
(
self.text_encoder,
self.text_encoder_2,
) = self.sd_model.sdxl_clip()
def unload_clip_sdxl(self):
del self.text_encoder, self.text_encoder_2
self.text_encoder = None
self.text_encoder_2 = None
def load_unet(self):
if self.unet is not None:
return
@@ -159,6 +190,179 @@ class StableDiffusionPipeline:
def unload_vae(self):
del self.vae
self.vae = None
gc.collect()
def encode_prompt_sdxl(
self,
prompt: str,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
self.tokenizer_2 = get_tokenizer("tokenizer_2")
self.load_clip_sdxl()
tokenizers = (
[self.tokenizer, self.tokenizer_2]
if self.tokenizer is not None
else [self.tokenizer_2]
)
text_encoders = (
[self.text_encoder, self.text_encoder_2]
if self.text_encoder is not None
else [self.text_encoder_2]
)
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt]
for prompt, tokenizer, text_encoder in zip(
prompts, tokenizers, text_encoders
):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[
-1
] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
)
print(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
text_encoder_output = text_encoder("forward", (text_input_ids,))
prompt_embeds = torch.from_numpy(text_encoder_output[0])
pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1])
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = (
negative_prompt is None
and self.config.force_zeros_for_empty_prompt
)
if (
do_classifier_free_guidance
and negative_prompt_embeds is None
and zero_out_negative_prompt
):
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(
pooled_prompt_embeds
)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(
negative_prompt
):
raise TypeError(
f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(
uncond_tokens, tokenizers, text_encoders
):
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_encoder_output = text_encoder(
"forward", (uncond_input.input_ids,)
)
negative_prompt_embeds = torch.from_numpy(
text_encoder_output[0]
)
negative_pooled_prompt_embeds = torch.from_numpy(
text_encoder_output[1]
)
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_prompt_embeds = torch.concat(
negative_prompt_embeds_list, dim=-1
)
if self.ondemand:
self.unload_clip_sdxl()
gc.collect()
# TODO: Look into dtype for text_encoder_2!
prompt_embeds = prompt_embeds.to(dtype=torch.float32)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32)
negative_prompt_embeds = negative_prompt_embeds.repeat(
1, num_images_per_prompt, 1
)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(
1, num_images_per_prompt
).view(bs_embed * num_images_per_prompt, -1)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
1, num_images_per_prompt
).view(bs_embed * num_images_per_prompt, -1)
return (
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
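For reference, the two encoders return penultimate-layer hidden states of width 768 (CLIPTextModel) and 1280 (CLIPTextModelWithProjection), assuming the standard SDXL text encoders; the torch.concat along the last dimension therefore yields the 2048-wide prompt_embeds that base_model.json declares for the SDXL UNet, and the pooled output of the second encoder provides the 1280-wide text_embeds. A shape-only sketch:
import torch

batch_size, max_len = 1, 77
embeds_1 = torch.zeros(batch_size, max_len, 768)   # text_encoder, hidden_states[-2]
embeds_2 = torch.zeros(batch_size, max_len, 1280)  # text_encoder_2, hidden_states[-2]
pooled_2 = torch.zeros(batch_size, 1280)           # pooled output of text_encoder_2

prompt_embeds = torch.concat([embeds_1, embeds_2], dim=-1)
print(prompt_embeds.shape)  # torch.Size([1, 77, 2048])
print(pooled_2.shape)       # torch.Size([1, 1280])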
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
@@ -186,6 +390,7 @@ class StableDiffusionPipeline:
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
@@ -298,6 +503,8 @@ class StableDiffusionPipeline:
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
gc.collect()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -306,6 +513,72 @@ class StableDiffusionPipeline:
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def produce_img_latents_sdxl(
self,
latents,
total_timesteps,
add_text_embeds,
add_time_ids,
prompt_embeds,
cpu_scheduling,
guidance_scale,
dtype,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
extra_step_kwargs = {"generator": None}
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
).to(dtype)
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
prompt_embeds,
add_text_embeds,
add_time_ids,
guidance_scale,
),
send_to_host=False,
)
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
step_time = (time.time() - step_start_time) * 1000
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
gc.collect()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
return latents
def decode_latents_sdxl(self, latents):
latents = latents.to(torch.float32)
images = self.vae("forward", (latents,))
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
return pil_images
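The post-processing above maps the VAE output from roughly [-1, 1] into [0, 1] via x / 2 + 0.5 with a clamp, before scaling to 8-bit RGB. A tiny check of that mapping:
import torch

x = torch.tensor([-1.2, -1.0, 0.0, 1.0, 1.3])
print((x / 2 + 0.5).clamp(0, 1))  # tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000])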
@classmethod
def from_pretrained(
cls,
@@ -356,6 +629,7 @@ class StableDiffusionPipeline:
"OutpaintPipeline",
]
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]
sd_model = SharkifyStableDiffusionModel(
model_id,
@@ -372,6 +646,7 @@ class StableDiffusionPipeline:
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
is_sdxl=is_sdxl,
stencils=stencils,
use_lora=use_lora,
use_quantize=use_quantize,
@@ -503,6 +778,7 @@ class StableDiffusionPipeline:
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings.numpy()

View File

@@ -8,6 +8,15 @@
"dtype":"i64"
}
},
"sdxl_clip": {
"token" : {
"shape" : [
"1*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
@@ -179,6 +188,49 @@
"shape": [2],
"dtype": "i64"
}
},
"stabilityai/stable-diffusion-xl-base-1.0": {
"latents": {
"shape": [
"2*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"prompt_embeds": {
"shape": [
"2*batch_size",
"max_len",
2048
],
"dtype": "f32"
},
"text_embeds": {
"shape": [
"2*batch_size",
1280
],
"dtype": "f32"
},
"time_ids": {
"shape": [
"2*batch_size",
6
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 1,
"dtype": "f32"
}
}
},
"stencil_adaptor": {

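For reference, SharkifyStableDiffusionModel stores height and width already divided by 8, so for batch_size=1, a 1024x1024 image and max_length=77 the symbolic shapes in the stabilityai/stable-diffusion-xl-base-1.0 entry above resolve to roughly the following concrete UNet input shapes (a sketch of the resolved values, not the actual shape-resolution code):
# Illustrative resolution for batch_size=1, height=width=1024 (stored as 1024 // 8 = 128),
# max_length=77.
batch_size, height, width, max_len = 1, 1024 // 8, 1024 // 8, 77

unet_input_shapes = {
    "latents": (2 * batch_size, 4, height, width),     # (2, 4, 128, 128)
    "timesteps": (1,),
    "prompt_embeds": (2 * batch_size, max_len, 2048),  # (2, 77, 2048)
    "text_embeds": (2 * batch_size, 1280),             # (2, 1280)
    "time_ids": (2 * batch_size, 6),                   # (2, 6)
    "guidance_scale": (1,),
}
print(unet_input_shapes)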
View File

@@ -85,7 +85,7 @@ p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 769, 8),
choices=range(128, 1025, 8),
help="The height of the output image.",
)
@@ -93,7 +93,7 @@ p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 769, 8),
choices=range(128, 1025, 8),
help="The width of the output image.",
)

View File

@@ -565,9 +565,10 @@ def get_opt_flags(model, precision="fp16"):
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
if "vae" not in model:
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags

View File

@@ -109,6 +109,12 @@ if __name__ == "__main__":
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
# SDXL
txt2img_sdxl_inf,
txt2img_sdxl_web,
txt2img_sdxl_custom_model,
txt2img_sdxl_gallery,
txt2img_sdxl_status,
# h2ogpt_upload,
# h2ogpt_web,
img2img_web,
@@ -253,6 +259,8 @@ if __name__ == "__main__":
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
with gr.TabItem(label="Text-to-Image-SDXL (Experimental)", id=13):
txt2img_sdxl_web.render()
actual_port = app.usable_port()
if actual_port != args.server_port:

View File

@@ -10,6 +10,13 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.txt2img_sdxl_ui import (
txt2img_sdxl_inf,
txt2img_sdxl_web,
txt2img_sdxl_custom_model,
txt2img_sdxl_gallery,
txt2img_sdxl_status,
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
img2img_web,

View File

@@ -0,0 +1,458 @@
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
from math import ceil
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
save_output_img,
prompt_examples,
Image2ImagePipeline,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
# Capture initial values of iree_vulkan_target_triple, iree_metal_target_platform, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platform = args.iree_metal_target_platform
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
def txt2img_sdxl_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
ondemand: bool,
repeatable_seeds: bool,
):
if precision != "fp16":
print("Currently only fp16 is supported for SDXL; forcing precision to fp16.")
precision = "fp16"
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
# a .safetensors or .ckpt file in the custom model path
if model_id in get_custom_model_files():
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = model_id
# if custom_vae != "None":
# args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = ""
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"txt2img_sdxl",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-xl-base-1.0"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
# For SDXL we set max_length to 77.
print("Setting max_length = 77")
max_length = 77
if global_obj.get_cfg_obj().ondemand:
print("Running txt2img in memory efficient mode.")
txt2img_sdxl_obj = Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=precision,
max_length=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=global_obj.get_cfg_obj().ondemand,
)
global_obj.set_sd_obj(txt2img_sdxl_obj)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], seeds[current_batch])
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Text-to-Image-SDXL",
current_batch + 1,
batch_count,
batch_size,
)
return generated_imgs, text_output, ""
with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
txt2img_sdxl_custom_model = gr.Dropdown(
label="Models",
info="Select a model, or enter a HuggingFace model ID or a Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-xl-base-1.0",
choices=[
"stabilityai/stable-diffusion-xl-base-1.0"
],
allow_custom_value=True,
scale=2,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=2,
elem_id="negative_prompt_box",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="DDIM",
choices=["DDIM"],
allow_custom_value=True,
visible=False,
)
with gr.Column():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
1024,
value=1024,
step=8,
label="Height",
visible=False,
)
width = gr.Slider(
1024,
value=1024,
step=8,
label="Width",
visible=False,
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
txt2img_sdxl_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value=f"{t2i_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
txt2img_sdxl_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
kwargs = dict(
fn=txt2img_sdxl_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
txt2img_sdxl_custom_model,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
ondemand,
repeatable_seeds,
],
outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status],
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=txt2img_sdxl_status,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)