[SDXL] Add SDXL pipeline to SHARK (#1941)
* [SDXL] Add SDXL pipeline to SHARK

  This commit adds the SDXL pipeline to SHARK.

  Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* (SDXL) Fix --ondemand and vae scale factor use, and fix VAE flags.

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
96  apps/stable_diffusion/scripts/txt2img_sdxl.py  (new file)
@@ -0,0 +1,96 @@
import torch
import time
from apps.stable_diffusion.src import (
    args,
    Text2ImageSDXLPipeline,
    get_schedulers,
    set_init_device_flags,
    utils,
    clear_all,
    save_output_img,
)


def main():
    if args.clear_all:
        clear_all()

    # TODO: prompt_embeds and text_embeds form base_model.json requires fixing
    args.precision = "fp16"
    args.height = 1024
    args.width = 1024
    args.max_length = 77
    args.scheduler = "DDIM"
    print(
        "Using default supported configuration for SDXL :-\nprecision=fp16, width*height= 1024*1024, max_length=77 and scheduler=DDIM"
    )
    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
    schedulers = get_schedulers(args.hf_model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = args.seed
    txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
        scheduler=scheduler_obj,
        import_mlir=args.import_mlir,
        model_id=args.hf_model_id,
        ckpt_loc=args.ckpt_loc,
        precision=args.precision,
        max_length=args.max_length,
        batch_size=args.batch_size,
        height=args.height,
        width=args.width,
        use_base_vae=args.use_base_vae,
        use_tuned=args.use_tuned,
        custom_vae=args.custom_vae,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        debug=args.import_debug if args.import_mlir else False,
        use_lora=args.use_lora,
        use_quantize=args.use_quantize,
        ondemand=args.ondemand,
    )

    seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
    for current_batch in range(args.batch_count):
        start_time = time.time()
        generated_imgs = txt2img_obj.generate_images(
            args.prompts,
            args.negative_prompts,
            args.batch_size,
            args.height,
            args.width,
            args.steps,
            args.guidance_scale,
            seeds[current_batch],
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
            args.max_embeddings_multiples,
        )
        total_time = time.time() - start_time
        text_output = f"prompt={args.prompts}"
        text_output += f"\nnegative prompt={args.negative_prompts}"
        text_output += (
            f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
        )
        text_output += f"\nscheduler={args.scheduler}, device={args.device}"
        text_output += (
            f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
        )
        text_output += (
            f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
        )
        text_output += (
            f", batch size={args.batch_size}, max_length={args.max_length}"
        )
        # TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
        text_output += txt2img_obj.log
        text_output += f"\nTotal image generation time: {total_time:.4f}sec"

        save_output_img(generated_imgs[0], seed)
        print(text_output)


if __name__ == "__main__":
    main()
@@ -9,6 +9,7 @@ from apps.stable_diffusion.src.utils import (
)
from apps.stable_diffusion.src.pipelines import (
    Text2ImagePipeline,
    Text2ImageSDXLPipeline,
    Image2ImagePipeline,
    InpaintPipeline,
    OutpaintPipeline,
@@ -1,5 +1,5 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection
from collections import defaultdict
from pathlib import Path
import torch
@@ -24,6 +24,8 @@ from apps.stable_diffusion.src.utils import (
    get_stencil_model_id,
    update_lora_weight,
)
from shark.shark_downloader import download_public_file
from shark.shark_inference import SharkInference


# These shapes are parameter dependent.
@@ -55,6 +57,10 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
            new_shape.append(math.ceil(height / div_val))
        elif "width" in shape[i]:
            new_shape.append(math.ceil(width / div_val))
        elif "+" in shape[i]:
            # Currently this case only hits for SDXL. So, in case any other
            # case requires this operator, change this.
            new_shape.append(height + width)
        else:
            new_shape.append(shape[i])
    return new_shape
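Note: the new "+" branch above lets the symbolic shape strings used by the SDXL entries in base_model.json resolve. A rough standalone sketch of the substitution idea, with the divisor handling (div_val) elided and the helper name made up:

    def resolve_shape_sketch(shape, max_len=77, height=128, width=128, batch_size=1):
        # Sketch only: mirrors the branch structure of replace_shape_str above;
        # the real function also derives a divisor from strings such as "height/2".
        new_shape = []
        for dim in shape:
            if not isinstance(dim, str):
                new_shape.append(dim)
            elif "batch_size" in dim:
                mult = int(dim.split("*")[0]) if "*" in dim else 1
                new_shape.append(mult * batch_size)
            elif "max_len" in dim:
                new_shape.append(max_len)
            elif "height" in dim:
                new_shape.append(height)
            elif "width" in dim:
                new_shape.append(width)
            elif "+" in dim:
                new_shape.append(height + width)  # the new SDXL-only branch
            else:
                new_shape.append(dim)
        return new_shape

    # ["2*batch_size", "max_len", 2048] resolves to [2, 77, 2048] (SDXL prompt_embeds).
    print(resolve_shape_sketch(["2*batch_size", "max_len", 2048]))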
@@ -67,6 +73,70 @@ def check_compilation(model, model_name):
        )


def shark_compile_after_ir(
    module_name,
    device,
    vmfb_path,
    precision,
    ir_path=None,
):
    if ir_path:
        print(f"[DEBUG] mlir found at {ir_path.absolute()}")

    module = SharkInference(
        mlir_module=ir_path,
        device=device,
        mlir_dialect="tm_tensor",
    )
    print(f"Will get extra flag for {module_name} and precision = {precision}")
    path = module.save_module(
        vmfb_path.parent.absolute(),
        vmfb_path.stem,
        extra_args=get_opt_flags(module_name, precision=precision),
    )
    print(f"Saved {module_name} vmfb at {path}")
    module.load_module(path)
    return module


def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision):
    name_split = extended_model_name.split("_")
    if "vae" in model_name:
        name_split[5] = "fp32"
    extended_model_name_for_vmfb = "_".join(name_split)
    extended_model_name_for_mlir = "_".join(name_split[:-1])
    vmfb_path = Path(extended_model_name_for_vmfb + ".vmfb")
    if "vulkan" in device:
        _device = args.iree_vulkan_target_triple
        _device = _device.replace("-", "_")
        vmfb_path = Path(extended_model_name_for_vmfb + f"_{_device}.vmfb")
    if vmfb_path.exists():
        shark_module = SharkInference(
            None,
            device=device,
            mlir_dialect="tm_tensor",
        )
        print(f"loading existing vmfb from: {vmfb_path}")
        shark_module.load_module(vmfb_path, extra_args=[])
        return shark_module, None
    mlir_path = Path(extended_model_name_for_mlir + ".mlir")
    if not mlir_path.exists():
        print(f"Looking into gs://shark_tank/SDXL/mlir/{mlir_path.name}")
        download_public_file(
            f"gs://shark_tank/SDXL/mlir/{mlir_path.name}",
            mlir_path.absolute(),
            single_file=True,
        )
    if mlir_path.exists():
        return (
            shark_compile_after_ir(
                model_name, device, vmfb_path, precision, mlir_path
            ),
            None,
        )
    return None, None


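Note: process_vmfb_ir_sdxl above short-circuits compilation when artifacts already exist. A condensed sketch of the lookup order it implements (illustrative only; the helper name below is made up and error handling is omitted):

    from pathlib import Path

    def sdxl_artifact_lookup_sketch(extended_model_name):
        # 1. Reuse a previously compiled .vmfb if it is already on disk.
        vmfb = Path(extended_model_name + ".vmfb")
        if vmfb.exists():
            return "load_vmfb", vmfb
        # 2. Otherwise look for (or download) the pre-generated MLIR from
        #    gs://shark_tank/SDXL/mlir/ and compile it locally.
        mlir = Path(extended_model_name + ".mlir")
        if mlir.exists():
            return "compile_mlir", mlir
        # 3. Fall back to importing/tracing the model from scratch (the caller's job).
        return "import_from_torch", None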
class SharkifyStableDiffusionModel:
    def __init__(
        self,
@@ -86,6 +156,7 @@ class SharkifyStableDiffusionModel:
        generate_vmfb: bool = True,
        is_inpaint: bool = False,
        is_upscaler: bool = False,
        is_sdxl: bool = False,
        stencils: list[str] = [],
        use_lora: str = "",
        use_quantize: str = None,
@@ -93,6 +164,7 @@ class SharkifyStableDiffusionModel:
    ):
        self.check_params(max_len, width, height)
        self.max_len = max_len
        self.is_sdxl = is_sdxl
        self.height = height // 8
        self.width = width // 8
        self.batch_size = batch_size
@@ -175,6 +247,7 @@ class SharkifyStableDiffusionModel:
        model_name = {}
        sub_model_list = [
            "clip",
            "clip2",
            "unet",
            "unet512",
            "stencil_unet",
@@ -343,6 +416,76 @@ class SharkifyStableDiffusionModel:
        )
        return shark_vae, vae_mlir

    def get_vae_sdxl(self):
        # TODO: Remove this after convergence with shark_tank. This should just be part of
        # opt_params.py.
        shark_module_or_none = process_vmfb_ir_sdxl(
            self.model_name["vae"], "vae", args.device, self.precision
        )
        if shark_module_or_none[0]:
            return shark_module_or_none

        class VaeModel(torch.nn.Module):
            def __init__(
                self,
                model_id=self.model_id,
                base_vae=self.base_vae,
                custom_vae=self.custom_vae,
                low_cpu_mem_usage=False,
            ):
                super().__init__()
                self.vae = None
                if custom_vae == "":
                    self.vae = AutoencoderKL.from_pretrained(
                        model_id,
                        subfolder="vae",
                        low_cpu_mem_usage=low_cpu_mem_usage,
                    )
                elif not isinstance(custom_vae, dict):
                    self.vae = AutoencoderKL.from_pretrained(
                        custom_vae,
                        subfolder="vae",
                        low_cpu_mem_usage=low_cpu_mem_usage,
                    )
                else:
                    self.vae = AutoencoderKL.from_pretrained(
                        model_id,
                        subfolder="vae",
                        low_cpu_mem_usage=low_cpu_mem_usage,
                    )
                    self.vae.load_state_dict(custom_vae)

            def forward(self, latents):
                image = self.vae.decode(latents / 0.13025, return_dict=False)[
                    0
                ]
                return image

        vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
        inputs = tuple(self.inputs["vae"])
        # Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL
        # pipeline.
        is_f16 = False
        save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
        if self.debug:
            os.makedirs(save_dir, exist_ok=True)
        shark_vae, vae_mlir = compile_through_fx(
            vae,
            inputs,
            is_f16=is_f16,
            use_tuned=self.use_tuned,
            extended_model_name=self.model_name["vae"],
            debug=self.debug,
            generate_vmfb=self.generate_vmfb,
            save_dir=save_dir,
            extra_args=get_opt_flags("vae", precision=self.precision),
            base_model_id=self.base_model_id,
            model_name="vae",
            precision=self.precision,
            return_mlir=self.return_mlir,
        )
        return shark_vae, vae_mlir

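Note: the hard-coded 0.13025 above is the SDXL VAE scaling factor (SD 1.x/2.x checkpoints use 0.18215), and the module is compiled with is_f16 = False because, as the inline comment says, the SDXL VAE overflows in float16. A small check, assuming network access to the Hugging Face hub, that the literal matches the model config:

    from diffusers import AutoencoderKL

    # Sketch: read the scaling factor from the config instead of hard-coding it.
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
    )
    print(vae.config.scaling_factor)  # 0.13025 for SDXL base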
    def get_controlled_unet(self, use_large=False):
        class ControlledUnetModel(torch.nn.Module):
            def __init__(
@@ -758,6 +901,93 @@ class SharkifyStableDiffusionModel:
        )
        return shark_unet, unet_mlir

    def get_unet_sdxl(self):
        # TODO: Remove this after convergence with shark_tank. This should just be part of
        # opt_params.py.
        shark_module_or_none = process_vmfb_ir_sdxl(
            self.model_name["unet"], "unet", args.device, self.precision
        )
        if shark_module_or_none[0]:
            return shark_module_or_none

        class UnetModel(torch.nn.Module):
            def __init__(
                self,
                model_id=self.model_id,
                low_cpu_mem_usage=False,
            ):
                super().__init__()
                self.unet = UNet2DConditionModel.from_pretrained(
                    model_id,
                    subfolder="unet",
                    low_cpu_mem_usage=low_cpu_mem_usage,
                )
                if (
                    args.attention_slicing is not None
                    and args.attention_slicing != "none"
                ):
                    if args.attention_slicing.isdigit():
                        self.unet.set_attention_slice(
                            int(args.attention_slicing)
                        )
                    else:
                        self.unet.set_attention_slice(args.attention_slicing)

            def forward(
                self,
                latent,
                timestep,
                prompt_embeds,
                text_embeds,
                time_ids,
                guidance_scale,
            ):
                added_cond_kwargs = {
                    "text_embeds": text_embeds,
                    "time_ids": time_ids,
                }
                noise_pred = self.unet.forward(
                    latent,
                    timestep,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=None,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
                return noise_pred

        unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
        is_f16 = True if self.precision == "fp16" else False
        inputs = tuple(self.inputs["unet"])
        save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
        input_mask = [True, True, True, True, True, True]
        if self.debug:
            os.makedirs(
                save_dir,
                exist_ok=True,
            )
        shark_unet, unet_mlir = compile_through_fx(
            unet,
            inputs,
            extended_model_name=self.model_name["unet"],
            is_f16=is_f16,
            f16_input_mask=input_mask,
            use_tuned=self.use_tuned,
            debug=self.debug,
            generate_vmfb=self.generate_vmfb,
            save_dir=save_dir,
            extra_args=get_opt_flags("unet", precision=self.precision),
            base_model_id=self.base_model_id,
            model_name="unet",
            precision=self.precision,
            return_mlir=self.return_mlir,
        )
        return shark_unet, unet_mlir

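Note: UnetModel.forward above folds classifier-free guidance into the compiled module: the inputs carry the unconditional and conditional halves stacked on the batch axis, and the module returns the already-combined noise prediction. The combination rule in isolation (illustrative, standalone):

    import torch

    def cfg_combine(noise_pred, guidance_scale):
        # noise_pred stacks the unconditional and conditional halves on the batch axis.
        uncond, cond = noise_pred.chunk(2)
        return uncond + guidance_scale * (cond - uncond)

    # With guidance_scale = 1.0 the result equals the conditional prediction exactly;
    # larger values push further along the (cond - uncond) direction.
    demo = cfg_combine(torch.randn(2, 4, 128, 128), guidance_scale=7.5)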
    def get_clip(self):
        class CLIPText(torch.nn.Module):
            def __init__(
@@ -805,6 +1035,78 @@ class SharkifyStableDiffusionModel:
        )
        return shark_clip, clip_mlir

    def get_clip_sdxl(self, clip_index=1):
        if clip_index == 1:
            extended_model_name = self.model_name["clip"]
            model_name = "clip"
        else:
            extended_model_name = self.model_name["clip2"]
            model_name = "clip2"
        # TODO: Remove this after convergence with shark_tank. This should just be part of
        # opt_params.py.
        shark_module_or_none = process_vmfb_ir_sdxl(
            extended_model_name, f"clip", args.device, self.precision
        )
        if shark_module_or_none[0]:
            return shark_module_or_none

        class CLIPText(torch.nn.Module):
            def __init__(
                self,
                model_id=self.model_id,
                low_cpu_mem_usage=False,
                clip_index=1,
            ):
                super().__init__()
                if clip_index == 1:
                    self.text_encoder = CLIPTextModel.from_pretrained(
                        model_id,
                        subfolder="text_encoder",
                        low_cpu_mem_usage=low_cpu_mem_usage,
                    )
                else:
                    self.text_encoder = (
                        CLIPTextModelWithProjection.from_pretrained(
                            model_id,
                            subfolder="text_encoder_2",
                            low_cpu_mem_usage=low_cpu_mem_usage,
                        )
                    )

            def forward(self, input):
                prompt_embeds = self.text_encoder(
                    input,
                    output_hidden_states=True,
                )
                # We are only ALWAYS interested in the pooled output of the final text encoder
                pooled_prompt_embeds = prompt_embeds[0]
                prompt_embeds = prompt_embeds.hidden_states[-2]
                return prompt_embeds, pooled_prompt_embeds

        clip_model = CLIPText(
            low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index
        )
        save_dir = os.path.join(self.sharktank_dir, extended_model_name)
        if self.debug:
            os.makedirs(
                save_dir,
                exist_ok=True,
            )
        shark_clip, clip_mlir = compile_through_fx(
            clip_model,
            tuple(self.inputs["clip"]),
            extended_model_name=extended_model_name,
            debug=self.debug,
            generate_vmfb=self.generate_vmfb,
            save_dir=save_dir,
            extra_args=get_opt_flags("clip", precision="fp32"),
            base_model_id=self.base_model_id,
            model_name="clip",
            precision=self.precision,
            return_mlir=self.return_mlir,
        )
        return shark_clip, clip_mlir
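Note: get_clip_sdxl compiles two encoders: clip_index=1 wraps CLIPTextModel (subfolder "text_encoder") and clip_index=2 wraps CLIPTextModelWithProjection (subfolder "text_encoder_2"). Each wrapper returns hidden_states[-2] plus output[0]; as the inline comment says, only the pooled output of the final (second) encoder is used downstream, and the pipeline concatenates the two hidden states on the feature axis. A shape-level sketch, assuming the standard SDXL-base encoder widths of 768 and 1280:

    import torch

    # Illustrative placeholder tensors; real values come from the two compiled encoders.
    hidden_1 = torch.zeros(2, 77, 768)    # CLIPTextModel, hidden_states[-2]
    hidden_2 = torch.zeros(2, 77, 1280)   # CLIPTextModelWithProjection, hidden_states[-2]
    pooled_2 = torch.zeros(2, 1280)       # pooled output of the second encoder

    prompt_embeds = torch.concat([hidden_1, hidden_2], dim=-1)  # (2, 77, 2048)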
    def process_custom_vae(self):
        custom_vae = self.custom_vae.lower()
        if not custom_vae.endswith((".ckpt", ".safetensors")):
@@ -837,7 +1139,9 @@ class SharkifyStableDiffusionModel:
        }
        return vae_dict

    def compile_unet_variants(self, model, use_large=False):
    def compile_unet_variants(self, model, use_large=False, base_model=""):
        if self.is_sdxl:
            return self.get_unet_sdxl()
        if model == "unet":
            if self.is_upscaler:
                return self.get_unet_upscaler(use_large=use_large)
@@ -879,6 +1183,22 @@ class SharkifyStableDiffusionModel:
        except Exception as e:
            sys.exit(e)

    def sdxl_clip(self):
        try:
            self.inputs["clip"] = self.get_input_info_for(
                base_models["sdxl_clip"]
            )
            compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1)
            compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2)

            check_compilation(compiled_clip, "Clip")
            check_compilation(compiled_clip, "Clip2")
            if self.return_mlir:
                return clip_mlir, clip_mlir2
            return compiled_clip, compiled_clip2
        except Exception as e:
            sys.exit(e)

    def unet(self, use_large=False):
        try:
            stencil_count = 0
@@ -893,7 +1213,7 @@ class SharkifyStableDiffusionModel:
                    unet_inputs[self.base_model_id]
                )
                compiled_unet, unet_mlir = self.compile_unet_variants(
                    model, use_large=use_large
                    model, use_large=use_large, base_model=self.base_model_id
                )
            else:
                for model_id in unet_inputs:
@@ -904,7 +1224,7 @@ class SharkifyStableDiffusionModel:

                    try:
                        compiled_unet, unet_mlir = self.compile_unet_variants(
                            model, use_large=use_large
                            model, use_large=use_large, base_model=model_id
                        )
                    except Exception as e:
                        print(e)
@@ -943,7 +1263,10 @@ class SharkifyStableDiffusionModel:
        is_base_vae = self.base_vae
        if self.is_upscaler:
            self.base_vae = True
        compiled_vae, vae_mlir = self.get_vae()
        if self.is_sdxl:
            compiled_vae, vae_mlir = self.get_vae_sdxl()
        else:
            compiled_vae, vae_mlir = self.get_vae()
        self.base_vae = is_base_vae

        check_compilation(compiled_vae, "Vae")

@@ -123,8 +123,8 @@ def get_clip():
    return get_shark_model(bucket, model_name, iree_flags)


def get_tokenizer():
def get_tokenizer(subfolder="tokenizer"):
    tokenizer = CLIPTokenizer.from_pretrained(
        args.hf_model_id, subfolder="tokenizer"
        args.hf_model_id, subfolder=subfolder
    )
    return tokenizer

@@ -1,6 +1,9 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
    Text2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
    Text2ImageSDXLPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
    Image2ImagePipeline,
)

@@ -0,0 +1,214 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from typing import Union
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Text2ImageSDXLPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype
|
||||
):
|
||||
add_time_ids = list(
|
||||
original_size + crops_coords_top_left + target_size
|
||||
)
|
||||
|
||||
# self.unet.config.addition_time_embed_dim IS 256.
|
||||
# self.text_encoder_2.config.projection_dim IS 1280.
|
||||
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
|
||||
expected_add_embed_dim = 2816
|
||||
# self.unet.add_embedding.linear_1.in_features IS 2816.
|
||||
|
||||
if expected_add_embed_dim != passed_add_embed_dim:
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
||||
)
|
||||
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
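Note: the dimension check in _get_add_time_ids above is plain arithmetic: six time ids (original_size + crops_coords_top_left + target_size, each a pair) times the 256-wide addition time embedding, plus the 1280-wide projection of text_encoder_2, must equal the 2816 input features of unet.add_embedding.linear_1:

    # (orig_h, orig_w, crop_top, crop_left, tgt_h, tgt_w) for a 1024x1024 generation.
    add_time_ids = [1024, 1024, 0, 0, 1024, 1024]
    passed_add_embed_dim = 256 * len(add_time_ids) + 1280   # 256 * 6 + 1280 = 2816
    assert passed_add_embed_dim == 2816                      # unet.add_embedding.linear_1.in_features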
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the inital latent noise. Also handle out of range seeds.
|
||||
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get initial latents.
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings.
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt_sdxl(
|
||||
prompt=prompts,
|
||||
num_images_per_prompt=1,
|
||||
do_classifier_free_guidance=True,
|
||||
negative_prompt=neg_prompts,
|
||||
)
|
||||
|
||||
# Prepare timesteps.
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# Prepare added time ids & embeddings.
|
||||
original_size = (height, width)
|
||||
target_size = (height, width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
|
||||
prompt_embeds = torch.cat(
|
||||
[negative_prompt_embeds, prompt_embeds], dim=0
|
||||
)
|
||||
add_text_embeds = torch.cat(
|
||||
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
|
||||
)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds
|
||||
add_text_embeds = add_text_embeds.to(dtype)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * 1, 1)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(dtype)
|
||||
prompt_embeds = prompt_embeds.to(dtype)
|
||||
add_time_ids = add_time_ids.to(dtype)
|
||||
|
||||
# Get Image latents.
|
||||
latents = self.produce_img_latents_sdxl(
|
||||
init_latents,
|
||||
timesteps,
|
||||
add_text_embeds,
|
||||
add_time_ids,
|
||||
prompt_embeds,
|
||||
cpu_scheduling,
|
||||
guidance_scale,
|
||||
dtype,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images.
|
||||
all_imgs = []
|
||||
self.load_vae()
|
||||
# imgs = self.decode_latents_sdxl(None)
|
||||
# all_imgs.extend(imgs)
|
||||
for i in range(0, latents.shape[0], batch_size):
|
||||
imgs = self.decode_latents_sdxl(latents[i : i + batch_size])
|
||||
all_imgs.extend(imgs)
|
||||
if self.ondemand:
|
||||
self.unload_vae()
|
||||
|
||||
return all_imgs
|
||||
@@ -33,6 +33,8 @@ from apps.stable_diffusion.src.utils import (
|
||||
end_profiling,
|
||||
)
|
||||
import sys
|
||||
import gc
|
||||
from typing import List, Optional
|
||||
|
||||
SD_STATE_IDLE = "idle"
|
||||
SD_STATE_CANCEL = "cancel"
|
||||
@@ -63,6 +65,7 @@ class StableDiffusionPipeline:
|
||||
):
|
||||
self.vae = None
|
||||
self.text_encoder = None
|
||||
self.text_encoder_2 = None
|
||||
self.unet = None
|
||||
self.unet_512 = None
|
||||
self.model_max_length = 77
|
||||
@@ -106,6 +109,34 @@ class StableDiffusionPipeline:
|
||||
del self.text_encoder
|
||||
self.text_encoder = None
|
||||
|
||||
def load_clip_sdxl(self):
|
||||
if self.text_encoder and self.text_encoder_2:
|
||||
return
|
||||
|
||||
if self.import_mlir or self.use_lora:
|
||||
if not self.import_mlir:
|
||||
print(
|
||||
"Warning: LoRA provided but import_mlir not specified. "
|
||||
"Importing MLIR anyways."
|
||||
)
|
||||
self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip()
|
||||
else:
|
||||
try:
|
||||
# TODO: Fix this for SDXL
|
||||
self.text_encoder = get_clip()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
) = self.sd_model.sdxl_clip()
|
||||
|
||||
def unload_clip_sdxl(self):
|
||||
del self.text_encoder, self.text_encoder_2
|
||||
self.text_encoder = None
|
||||
self.text_encoder_2 = None
|
||||
|
||||
def load_unet(self):
|
||||
if self.unet is not None:
|
||||
return
|
||||
@@ -159,6 +190,179 @@ class StableDiffusionPipeline:
|
||||
def unload_vae(self):
|
||||
del self.vae
|
||||
self.vae = None
|
||||
gc.collect()
|
||||
|
||||
def encode_prompt_sdxl(
|
||||
self,
|
||||
prompt: str,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[str] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
self.tokenizer_2 = get_tokenizer("tokenizer_2")
|
||||
self.load_clip_sdxl()
|
||||
tokenizers = (
|
||||
[self.tokenizer, self.tokenizer_2]
|
||||
if self.tokenizer is not None
|
||||
else [self.tokenizer_2]
|
||||
)
|
||||
text_encoders = (
|
||||
[self.text_encoder, self.text_encoder_2]
|
||||
if self.text_encoder is not None
|
||||
else [self.text_encoder_2]
|
||||
)
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
prompt_embeds_list = []
|
||||
prompts = [prompt, prompt]
|
||||
for prompt, tokenizer, text_encoder in zip(
|
||||
prompts, tokenizers, text_encoders
|
||||
):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(
|
||||
prompt, padding="longest", return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
||||
-1
|
||||
] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = tokenizer.batch_decode(
|
||||
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
print(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
text_encoder_output = text_encoder("forward", (text_input_ids,))
|
||||
prompt_embeds = torch.from_numpy(text_encoder_output[0])
|
||||
pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1])
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
|
||||
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
zero_out_negative_prompt = (
|
||||
negative_prompt is None
|
||||
and self.config.force_zeros_for_empty_prompt
|
||||
)
|
||||
if (
|
||||
do_classifier_free_guidance
|
||||
and negative_prompt_embeds is None
|
||||
and zero_out_negative_prompt
|
||||
):
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
negative_pooled_prompt_embeds = torch.zeros_like(
|
||||
pooled_prompt_embeds
|
||||
)
|
||||
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt_2 = negative_prompt
|
||||
|
||||
uncond_tokens: List[str]
|
||||
if prompt is not None and type(prompt) is not type(
|
||||
negative_prompt
|
||||
):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt, negative_prompt_2]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = [negative_prompt, negative_prompt_2]
|
||||
|
||||
negative_prompt_embeds_list = []
|
||||
for negative_prompt, tokenizer, text_encoder in zip(
|
||||
uncond_tokens, tokenizers, text_encoders
|
||||
):
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = tokenizer(
|
||||
negative_prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_encoder_output = text_encoder(
|
||||
"forward", (uncond_input.input_ids,)
|
||||
)
|
||||
negative_prompt_embeds = torch.from_numpy(
|
||||
text_encoder_output[0]
|
||||
)
|
||||
negative_pooled_prompt_embeds = torch.from_numpy(
|
||||
text_encoder_output[1]
|
||||
)
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
negative_prompt_embeds = torch.concat(
|
||||
negative_prompt_embeds_list, dim=-1
|
||||
)
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_clip_sdxl()
|
||||
gc.collect()
|
||||
|
||||
# TODO: Look into dtype for text_encoder_2!
|
||||
prompt_embeds = prompt_embeds.to(dtype=torch.float32)
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(
|
||||
bs_embed * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
||||
1, num_images_per_prompt, 1
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(
|
||||
1, num_images_per_prompt
|
||||
).view(bs_embed * num_images_per_prompt, -1)
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
|
||||
1, num_images_per_prompt
|
||||
).view(bs_embed * num_images_per_prompt, -1)
|
||||
|
||||
return (
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
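Note: encode_prompt_sdxl above tokenizes the prompt once per tokenizer and runs each compiled encoder through text_encoder("forward", (input_ids,)); SharkInference returns numpy arrays, hence the torch.from_numpy calls. generate_images later stacks the negative and positive embeddings on the batch axis so a single UNet invocation serves classifier-free guidance. A minimal sketch of that stacking with placeholder tensors:

    import torch

    prompt_embeds          = torch.zeros(1, 77, 2048)  # positive prompt, both encoders concatenated
    negative_prompt_embeds = torch.zeros(1, 77, 2048)  # negative prompt, same layout

    # Negative first, positive second, matching the uncond/cond split in UnetModel.forward.
    unet_prompt_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)  # (2, 77, 2048)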
def encode_prompts(self, prompts, neg_prompts, max_length):
|
||||
# Tokenize text and get embeddings
|
||||
@@ -186,6 +390,7 @@ class StableDiffusionPipeline:
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_clip()
|
||||
gc.collect()
|
||||
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
|
||||
|
||||
return text_embeddings
|
||||
@@ -298,6 +503,8 @@ class StableDiffusionPipeline:
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
gc.collect()
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -306,6 +513,72 @@ class StableDiffusionPipeline:
|
||||
all_latents = torch.cat(latent_history, dim=0)
|
||||
return all_latents
|
||||
|
||||
def produce_img_latents_sdxl(
|
||||
self,
|
||||
latents,
|
||||
total_timesteps,
|
||||
add_text_embeds,
|
||||
add_time_ids,
|
||||
prompt_embeds,
|
||||
cpu_scheduling,
|
||||
guidance_scale,
|
||||
dtype,
|
||||
):
|
||||
# return None
|
||||
self.status = SD_STATE_IDLE
|
||||
step_time_sum = 0
|
||||
extra_step_kwargs = {"generator": None}
|
||||
self.load_unet()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
).to(dtype)
|
||||
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
add_text_embeds,
|
||||
add_time_ids,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
||||
)[0]
|
||||
|
||||
step_time = (time.time() - step_start_time) * 1000
|
||||
step_time_sum += step_time
|
||||
|
||||
if self.status == SD_STATE_CANCEL:
|
||||
break
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
gc.collect()
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
return latents
|
||||
|
||||
def decode_latents_sdxl(self, latents):
|
||||
latents = latents.to(torch.float32)
|
||||
images = self.vae("forward", (latents,))
|
||||
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
|
||||
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
images = (images * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
|
||||
|
||||
return pil_images
|
||||
|
||||
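Note: decode_latents_sdxl above maps the VAE output from the [-1, 1] range to 8-bit RGB. A worked example of the postprocessing chain on a few sample values:

    import torch

    x = torch.tensor([-1.0, 0.0, 0.5, 1.0])       # raw VAE decode output range
    x = (x / 2 + 0.5).clamp(0, 1)                 # -> [0.0, 0.5, 0.75, 1.0]
    pixels = (x * 255).round().to(torch.uint8)    # -> [0, 128, 191, 255]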
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -356,6 +629,7 @@ class StableDiffusionPipeline:
|
||||
"OutpaintPipeline",
|
||||
]
|
||||
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
|
||||
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]
|
||||
|
||||
sd_model = SharkifyStableDiffusionModel(
|
||||
model_id,
|
||||
@@ -372,6 +646,7 @@ class StableDiffusionPipeline:
|
||||
debug=debug,
|
||||
is_inpaint=is_inpaint,
|
||||
is_upscaler=is_upscaler,
|
||||
is_sdxl=is_sdxl,
|
||||
stencils=stencils,
|
||||
use_lora=use_lora,
|
||||
use_quantize=use_quantize,
|
||||
@@ -503,6 +778,7 @@ class StableDiffusionPipeline:
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_clip()
|
||||
gc.collect()
|
||||
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
|
||||
|
||||
return text_embeddings.numpy()
|
||||
|
||||
@@ -8,6 +8,15 @@
|
||||
"dtype":"i64"
|
||||
}
|
||||
},
|
||||
"sdxl_clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"1*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
@@ -179,6 +188,49 @@
|
||||
"shape": [2],
|
||||
"dtype": "i64"
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"prompt_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
2048
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"text_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
1280
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"time_ids": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
6
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
}
|
||||
},
|
||||
"stencil_adaptor": {
|
||||
|
||||
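Note: the "stabilityai/stable-diffusion-xl-base-1.0" entry above describes the compiled SDXL UNet inputs symbolically. With batch_size=1, max_len=77 and a 1024x1024 image (SharkifyStableDiffusionModel divides height and width by 8, giving 128), they resolve to:

    # Resolved UNet input shapes for batch_size=1, max_len=77, 1024x1024 (latent 128x128):
    sdxl_unet_inputs = {
        "latents":        (2, 4, 128, 128),  # "2*batch_size", 4, "height", "width"
        "timesteps":      (1,),
        "prompt_embeds":  (2, 77, 2048),
        "text_embeds":    (2, 1280),
        "time_ids":       (2, 6),
        "guidance_scale": (1,),
    }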
@@ -85,7 +85,7 @@ p.add_argument(
    "--height",
    type=int,
    default=512,
    choices=range(128, 769, 8),
    choices=range(128, 1025, 8),
    help="The height of the output image.",
)

@@ -93,7 +93,7 @@ p.add_argument(
    "--width",
    type=int,
    default=512,
    choices=range(128, 769, 8),
    choices=range(128, 1025, 8),
    help="The width of the output image.",
)


@@ -565,9 +565,10 @@ def get_opt_flags(model, precision="fp16"):
        iree_flags += opt_flags[model][is_tuned][precision][
            "specified_compilation_flags"
        ][device]
    # Due to lack of support for multi-reduce, we always collapse reduction
    # dims before dispatch formation right now.
    iree_flags += ["--iree-flow-collapse-reduction-dims"]
    if "vae" not in model:
        # Due to lack of support for multi-reduce, we always collapse reduction
        # dims before dispatch formation right now.
        iree_flags += ["--iree-flow-collapse-reduction-dims"]
    return iree_flags

@@ -109,6 +109,12 @@ if __name__ == "__main__":
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
# SDXL
|
||||
txt2img_sdxl_inf,
|
||||
txt2img_sdxl_web,
|
||||
txt2img_sdxl_custom_model,
|
||||
txt2img_sdxl_gallery,
|
||||
txt2img_sdxl_status,
|
||||
# h2ogpt_upload,
|
||||
# h2ogpt_web,
|
||||
img2img_web,
|
||||
@@ -253,6 +259,8 @@ if __name__ == "__main__":
|
||||
# h2ogpt_upload.render()
|
||||
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
|
||||
# h2ogpt_web.render()
|
||||
with gr.TabItem(label="Text-to-Image-SDXL (Experimental)", id=13):
|
||||
txt2img_sdxl_web.render()
|
||||
|
||||
actual_port = app.usable_port()
|
||||
if actual_port != args.server_port:
|
||||
|
||||
@@ -10,6 +10,13 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.txt2img_sdxl_ui import (
|
||||
txt2img_sdxl_inf,
|
||||
txt2img_sdxl_web,
|
||||
txt2img_sdxl_custom_model,
|
||||
txt2img_sdxl_gallery,
|
||||
txt2img_sdxl_status,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_inf,
|
||||
img2img_web,
|
||||
|
||||
458  apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py  (new file)
@@ -0,0 +1,458 @@
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from math import ceil
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImageSDXLPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
save_output_img,
|
||||
prompt_examples,
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_iree_metal_target_platform = args.iree_metal_target_platform
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
def txt2img_sdxl_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: str | int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
model_id: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
if precision != "fp16":
|
||||
print("currently we support fp16 for SDXL")
|
||||
precision = "fp16"
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.ondemand = ondemand
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files():
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = model_id
|
||||
|
||||
# if custom_vae != "None":
|
||||
# args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = ""
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
"txt2img_sdxl",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.iree_metal_target_platform = init_iree_metal_target_platform
|
||||
args.use_tuned = init_use_tuned
|
||||
args.import_mlir = init_import_mlir
|
||||
args.img_path = None
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
# For SDXL we set max_length as 77.
|
||||
print("Setting max_length = 77")
|
||||
max_length = 77
|
||||
if global_obj.get_cfg_obj().ondemand:
|
||||
print("Running txt2img in memory efficient mode.")
|
||||
txt2img_sdxl_obj = Text2ImageSDXLPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=precision,
|
||||
max_length=max_length,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=global_obj.get_cfg_obj().ondemand,
|
||||
)
|
||||
global_obj.set_sd_obj(txt2img_sdxl_obj)
|
||||
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
text_output = ""
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], seeds[current_batch])
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Text-to-Image-SDXL",
|
||||
current_batch + 1,
|
||||
batch_count,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
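Note: txt2img_sdxl_inf above is a generator, so the Gradio events it is wired to stream a gallery/status update after every batch rather than only at the end. A minimal standalone illustration of the same pattern (component names here are made up):

    import gradio as gr

    def stream_counts(n):
        # Yielding inside an event handler streams intermediate updates to the UI.
        for i in range(int(n)):
            yield f"finished batch {i + 1} of {int(n)}"

    with gr.Blocks() as demo:
        n = gr.Number(value=3, label="Batches")
        out = gr.Textbox(label="Status")
        gr.Button("Run").click(fn=stream_counts, inputs=n, outputs=out)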
with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
txt2img_sdxl_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
choices=[
|
||||
"stabilityai/stable-diffusion-xl-base-1.0"
|
||||
],
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value="DDIM",
|
||||
choices=["DDIM"],
|
||||
allow_custom_value=True,
|
||||
visible=False,
|
||||
)
|
||||
with gr.Column():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
1024,
|
||||
value=1024,
|
||||
step=8,
|
||||
label="Height",
|
||||
visible=False,
|
||||
)
|
||||
width = gr.Slider(
|
||||
1024,
|
||||
value=1024,
|
||||
step=8,
|
||||
label="Width",
|
||||
visible=False,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp16",
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
inputs=prompt,
|
||||
cache_examples=False,
|
||||
elem_id="prompt_examples",
|
||||
)
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
txt2img_sdxl_gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
columns=[2],
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"{t2i_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
txt2img_sdxl_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
|
||||
kwargs = dict(
|
||||
fn=txt2img_sdxl_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
txt2img_sdxl_custom_model,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
)
|
||||
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=txt2img_sdxl_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||