Revert "[SDXL] Add SDXL pipeline to SHARK (#1731)" (#1882)

This reverts commit 9f0a421764.

Author: Ean Garvey (committed via GitHub)
Date: 2023-10-09 20:01:44 -05:00
Parent: 6e409bfb77
Commit: 2004d16945

10 changed files with 40 additions and 875 deletions

View File

@@ -1,9 +1,9 @@
 import torch
+import transformers
 import time
 from apps.stable_diffusion.src import (
     args,
     Text2ImagePipeline,
-    Text2ImageSDXLPipeline,
     get_schedulers,
     set_init_device_flags,
     utils,
@@ -16,62 +16,31 @@ def main():
     if args.clear_all:
         clear_all()
 
-    # TODO: prompt_embeds and text_embeds form base_model.json requires fixing
     dtype = torch.float32 if args.precision == "fp32" else torch.half
     cpu_scheduling = not args.scheduler.startswith("Shark")
     set_init_device_flags()
     schedulers = get_schedulers(args.hf_model_id)
     scheduler_obj = schedulers[args.scheduler]
     seed = args.seed
-    if args.height == 1024:
-        assert (
-            args.width == 1024
-        ), "currently we support only 1024x1024 image size via SDXL"
-        assert args.precision == "fp16", "currently we support fp16 for SDXL"
-        # For SDXL we set max_length as 77.
-        args.max_length = 77
-        txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
-            scheduler=scheduler_obj,
-            import_mlir=args.import_mlir,
-            model_id=args.hf_model_id,
-            ckpt_loc=args.ckpt_loc,
-            precision=args.precision,
-            max_length=args.max_length,
-            batch_size=args.batch_size,
-            height=args.height,
-            width=args.width,
-            use_base_vae=args.use_base_vae,
-            use_tuned=args.use_tuned,
-            custom_vae=args.custom_vae,
-            low_cpu_mem_usage=args.low_cpu_mem_usage,
-            debug=args.import_debug if args.import_mlir else False,
-            use_lora=args.use_lora,
-            use_quantize=args.use_quantize,
-            ondemand=args.ondemand,
-        )
-    else:
-        assert (
-            args.height <= 768 and args.width <= 768
-        ), "height/width not in supported range"
-        txt2img_obj = Text2ImagePipeline.from_pretrained(
-            scheduler=scheduler_obj,
-            import_mlir=args.import_mlir,
-            model_id=args.hf_model_id,
-            ckpt_loc=args.ckpt_loc,
-            precision=args.precision,
-            max_length=args.max_length,
-            batch_size=args.batch_size,
-            height=args.height,
-            width=args.width,
-            use_base_vae=args.use_base_vae,
-            use_tuned=args.use_tuned,
-            custom_vae=args.custom_vae,
-            low_cpu_mem_usage=args.low_cpu_mem_usage,
-            debug=args.import_debug if args.import_mlir else False,
-            use_lora=args.use_lora,
-            use_quantize=args.use_quantize,
-            ondemand=args.ondemand,
-        )
+    txt2img_obj = Text2ImagePipeline.from_pretrained(
+        scheduler=scheduler_obj,
+        import_mlir=args.import_mlir,
+        model_id=args.hf_model_id,
+        ckpt_loc=args.ckpt_loc,
+        precision=args.precision,
+        max_length=args.max_length,
+        batch_size=args.batch_size,
+        height=args.height,
+        width=args.width,
+        use_base_vae=args.use_base_vae,
+        use_tuned=args.use_tuned,
+        custom_vae=args.custom_vae,
+        low_cpu_mem_usage=args.low_cpu_mem_usage,
+        debug=args.import_debug if args.import_mlir else False,
+        use_lora=args.use_lora,
+        use_quantize=args.use_quantize,
+        ondemand=args.ondemand,
+    )
 
     seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
     for current_batch in range(args.batch_count):
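
For reference, the branch deleted above amounts to the following dispatch (a condensed sketch, not code from either side of the diff): SDXL was only reachable at exactly 1024x1024 in fp16 with max_length forced to 77, while everything else stayed on the SD-1.x/2.x path with its 768-pixel cap.

    def pick_pipeline(height, width, precision):
        # Condensed from the deleted CLI branch above.
        if height == 1024:
            assert width == 1024, "currently we support only 1024x1024 image size via SDXL"
            assert precision == "fp16", "currently we support fp16 for SDXL"
            return "Text2ImageSDXLPipeline"
        assert height <= 768 and width <= 768, "height/width not in supported range"
        return "Text2ImagePipeline"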

View File

@@ -9,7 +9,6 @@ from apps.stable_diffusion.src.utils import (
 )
 from apps.stable_diffusion.src.pipelines import (
     Text2ImagePipeline,
-    Text2ImageSDXLPipeline,
     Image2ImagePipeline,
     InpaintPipeline,
     OutpaintPipeline,

View File

@@ -1,5 +1,5 @@
 from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
-from transformers import CLIPTextModel, CLIPTextModelWithProjection
+from transformers import CLIPTextModel
 from collections import defaultdict
 from pathlib import Path
 import torch
@@ -53,10 +53,6 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
             new_shape.append(math.ceil(height / div_val))
         elif "width" in shape[i]:
             new_shape.append(math.ceil(width / div_val))
-        elif "+" in shape[i]:
-            # Currently this case only hits for SDXL. So, in case any other
-            # case requires this operator, change this.
-            new_shape.append(height + width)
         else:
             new_shape.append(shape[i])
     return new_shape
@@ -88,7 +84,6 @@ class SharkifyStableDiffusionModel:
         generate_vmfb: bool = True,
         is_inpaint: bool = False,
         is_upscaler: bool = False,
-        is_sdxl: bool = False,
         use_stencil: str = None,
         use_lora: str = "",
         use_quantize: str = None,
@@ -96,14 +91,8 @@ class SharkifyStableDiffusionModel:
     ):
         self.check_params(max_len, width, height)
         self.max_len = max_len
-        self.is_sdxl = is_sdxl
-        self.height = height
-        self.width = width
-        if is_sdxl:
-            # We need to scale down the height/width by vae_scale_factor, which
-            # happens to be 8 in this case.
-            self.height = height // 8
-            self.width = width // 8
+        self.height = height // 8
+        self.width = width // 8
         self.batch_size = batch_size
         self.custom_weights = custom_weights
         self.use_quantize = use_quantize
@@ -185,7 +174,6 @@ class SharkifyStableDiffusionModel:
         model_name = {}
         sub_model_list = [
             "clip",
-            "clip2",
             "unet",
             "unet512",
             "stencil_unet",
@@ -353,71 +341,6 @@ class SharkifyStableDiffusionModel:
         )
         return shark_vae, vae_mlir
 
-    def get_vae_sdxl(self):
-        class VaeModel(torch.nn.Module):
-            def __init__(
-                self,
-                model_id=self.model_id,
-                base_vae=self.base_vae,
-                custom_vae=self.custom_vae,
-                low_cpu_mem_usage=False,
-            ):
-                super().__init__()
-                self.vae = None
-                if custom_vae == "":
-                    self.vae = AutoencoderKL.from_pretrained(
-                        model_id,
-                        subfolder="vae",
-                        low_cpu_mem_usage=low_cpu_mem_usage,
-                    )
-                elif not isinstance(custom_vae, dict):
-                    self.vae = AutoencoderKL.from_pretrained(
-                        custom_vae,
-                        subfolder="vae",
-                        low_cpu_mem_usage=low_cpu_mem_usage,
-                    )
-                else:
-                    self.vae = AutoencoderKL.from_pretrained(
-                        model_id,
-                        subfolder="vae",
-                        low_cpu_mem_usage=low_cpu_mem_usage,
-                    )
-                    self.vae.load_state_dict(custom_vae)
-
-            def forward(self, latents):
-                image = self.vae.decode(latents / 0.13025, return_dict=False)[
-                    0
-                ]
-                return image
-
-        vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
-        inputs = tuple(self.inputs["vae"])
-        # Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL
-        # pipeline.
-        is_f16 = False
-        save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
-        if self.debug:
-            os.makedirs(save_dir, exist_ok=True)
-        vae_name_split = self.model_name["vae"].split("_")
-        vae_name_split[5] = "fp32"
-        extended_model_name = "_".join(vae_name_split)
-        shark_vae, vae_mlir = compile_through_fx(
-            vae,
-            inputs,
-            is_f16=is_f16,
-            use_tuned=self.use_tuned,
-            extended_model_name=extended_model_name,
-            debug=self.debug,
-            generate_vmfb=self.generate_vmfb,
-            save_dir=save_dir,
-            extra_args=get_opt_flags("vae", precision=self.precision),
-            base_model_id=self.base_model_id,
-            model_name="vae",
-            precision=self.precision,
-            return_mlir=self.return_mlir,
-        )
-        return shark_vae, vae_mlir
-
     def get_controlled_unet(self, use_large=False):
         class ControlledUnetModel(torch.nn.Module):
             def __init__(
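
The deleted get_vae_sdxl hardcodes the 0.13025 divisor and pins the VAE to fp32 (its own comment notes the SDXL VAE overflows in fp16). In stock diffusers terms the divisor is just the VAE's configured scaling factor; a minimal sketch, assuming the stabilityai/stable-diffusion-xl-base-1.0 checkpoint:

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
    )
    print(vae.config.scaling_factor)  # 0.13025 for SDXL; SD 1.x/2.x uses 0.18215

    latents = torch.randn(1, 4, 128, 128)  # a 1024x1024 image decodes from a 128x128 latent
    with torch.no_grad():
        # Same operation as `latents / 0.13025` in the removed forward().
        image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]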
@@ -764,85 +687,6 @@ class SharkifyStableDiffusionModel:
         )
         return shark_unet, unet_mlir
 
-    def get_unet_sdxl(self):
-        class UnetModel(torch.nn.Module):
-            def __init__(
-                self,
-                model_id=self.model_id,
-                low_cpu_mem_usage=False,
-            ):
-                super().__init__()
-                self.unet = UNet2DConditionModel.from_pretrained(
-                    model_id,
-                    subfolder="unet",
-                    low_cpu_mem_usage=low_cpu_mem_usage,
-                )
-                if (
-                    args.attention_slicing is not None
-                    and args.attention_slicing != "none"
-                ):
-                    if args.attention_slicing.isdigit():
-                        self.unet.set_attention_slice(
-                            int(args.attention_slicing)
-                        )
-                    else:
-                        self.unet.set_attention_slice(args.attention_slicing)
-
-            def forward(
-                self,
-                latent,
-                timestep,
-                prompt_embeds,
-                text_embeds,
-                time_ids,
-                guidance_scale,
-            ):
-                added_cond_kwargs = {
-                    "text_embeds": text_embeds,
-                    "time_ids": time_ids,
-                }
-                noise_pred = self.unet.forward(
-                    latent,
-                    timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=None,
-                    added_cond_kwargs=added_cond_kwargs,
-                    return_dict=False,
-                )[0]
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (
-                    noise_pred_text - noise_pred_uncond
-                )
-                return noise_pred
-
-        unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
-        is_f16 = True if self.precision == "fp16" else False
-        inputs = tuple(self.inputs["unet"])
-        save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
-        input_mask = [True, True, True, True, True, True]
-        if self.debug:
-            os.makedirs(
-                save_dir,
-                exist_ok=True,
-            )
-        shark_unet, unet_mlir = compile_through_fx(
-            unet,
-            inputs,
-            extended_model_name=self.model_name["unet"],
-            is_f16=is_f16,
-            f16_input_mask=input_mask,
-            use_tuned=self.use_tuned,
-            debug=self.debug,
-            generate_vmfb=self.generate_vmfb,
-            save_dir=save_dir,
-            extra_args=get_opt_flags("unet", precision=self.precision),
-            base_model_id=self.base_model_id,
-            model_name="unet",
-            precision=self.precision,
-            return_mlir=self.return_mlir,
-        )
-        return shark_unet, unet_mlir
-
     def get_clip(self):
         class CLIPText(torch.nn.Module):
             def __init__(
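
Worth noting about the deleted UnetModel: it folds classifier-free guidance into the compiled graph, so the caller passes a doubled [uncond, cond] batch and the module chunks and mixes the prediction on-device instead of shipping the doubled tensor back to the host. The mixing step in isolation:

    import torch

    def mix_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # noise_pred stacks the unconditional and text-conditioned halves.
        noise_uncond, noise_text = noise_pred.chunk(2)
        return noise_uncond + guidance_scale * (noise_text - noise_uncond)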
@@ -890,68 +734,6 @@ class SharkifyStableDiffusionModel:
         )
         return shark_clip, clip_mlir
 
-    def get_clip_sdxl(self, clip_index=1):
-        class CLIPText(torch.nn.Module):
-            def __init__(
-                self,
-                model_id=self.model_id,
-                low_cpu_mem_usage=False,
-                clip_index=1,
-            ):
-                super().__init__()
-                if clip_index == 1:
-                    self.text_encoder = CLIPTextModel.from_pretrained(
-                        model_id,
-                        subfolder="text_encoder",
-                        low_cpu_mem_usage=low_cpu_mem_usage,
-                    )
-                else:
-                    self.text_encoder = (
-                        CLIPTextModelWithProjection.from_pretrained(
-                            model_id,
-                            subfolder="text_encoder_2",
-                            low_cpu_mem_usage=low_cpu_mem_usage,
-                        )
-                    )
-
-            def forward(self, input):
-                prompt_embeds = self.text_encoder(
-                    input,
-                    output_hidden_states=True,
-                )
-                # We are only ALWAYS interested in the pooled output of the final text encoder
-                pooled_prompt_embeds = prompt_embeds[0]
-                prompt_embeds = prompt_embeds.hidden_states[-2]
-                return prompt_embeds, pooled_prompt_embeds
-
-        clip_model = CLIPText(
-            low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index
-        )
-        if clip_index == 1:
-            model_name = self.model_name["clip"]
-        else:
-            model_name = self.model_name["clip2"]
-        save_dir = os.path.join(self.sharktank_dir, model_name)
-        if self.debug:
-            os.makedirs(
-                save_dir,
-                exist_ok=True,
-            )
-        shark_clip, clip_mlir = compile_through_fx(
-            clip_model,
-            tuple(self.inputs["clip"]),
-            extended_model_name=model_name,
-            debug=self.debug,
-            generate_vmfb=self.generate_vmfb,
-            save_dir=save_dir,
-            extra_args=get_opt_flags("clip", precision="fp32"),
-            base_model_id=self.base_model_id,
-            model_name="clip",
-            precision=self.precision,
-            return_mlir=self.return_mlir,
-        )
-        return shark_clip, clip_mlir
-
     def process_custom_vae(self):
         custom_vae = self.custom_vae.lower()
         if not custom_vae.endswith((".ckpt", ".safetensors")):
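
The deleted get_clip_sdxl compiles both SDXL text encoders and takes the penultimate hidden state from each (768- and 1280-dim, later concatenated into the 2048-dim prompt_embeds listed in base_models.json). A rough stock-transformers equivalent of what the two wrapped encoders produce (a sketch assuming the SDXL base repo; upstream diffusers takes the pooled output from the projection encoder, whereas the removed wrapper returned output[0] of whichever encoder it wrapped):

    import torch
    from transformers import (
        CLIPTokenizer,
        CLIPTextModel,
        CLIPTextModelWithProjection,
    )

    repo = "stabilityai/stable-diffusion-xl-base-1.0"
    tok = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
    tok2 = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer_2")
    enc = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder")
    enc2 = CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2")

    ids = tok("a photo", padding="max_length", max_length=77, return_tensors="pt").input_ids
    ids2 = tok2("a photo", padding="max_length", max_length=77, return_tensors="pt").input_ids
    with torch.no_grad():
        out = enc(ids, output_hidden_states=True)
        out2 = enc2(ids2, output_hidden_states=True)
    prompt_embeds = torch.cat([out.hidden_states[-2], out2.hidden_states[-2]], dim=-1)  # (1, 77, 2048)
    pooled = out2.text_embeds  # 1280-dim pooled projection from the second encoder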
@@ -984,9 +766,7 @@ class SharkifyStableDiffusionModel:
         }
         return vae_dict
 
-    def compile_unet_variants(self, model, use_large=False, base_model=""):
-        if self.is_sdxl:
-            return self.get_unet_sdxl()
+    def compile_unet_variants(self, model, use_large=False):
         if model == "unet":
             if self.is_upscaler:
                 return self.get_unet_upscaler(use_large=use_large)
@@ -1028,22 +808,6 @@ class SharkifyStableDiffusionModel:
         except Exception as e:
             sys.exit(e)
 
-    def sdxl_clip(self):
-        try:
-            self.inputs["clip"] = self.get_input_info_for(
-                base_models["sdxl_clip"]
-            )
-            compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1)
-            compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2)
-            check_compilation(compiled_clip, "Clip")
-            check_compilation(compiled_clip, "Clip2")
-            if self.return_mlir:
-                return clip_mlir, clip_mlir2
-            return compiled_clip, compiled_clip2
-        except Exception as e:
-            sys.exit(e)
-
     def unet(self, use_large=False):
         try:
             model = "stencil_unet" if self.use_stencil is not None else "unet"
@@ -1055,7 +819,7 @@ class SharkifyStableDiffusionModel:
                     unet_inputs[self.base_model_id]
                 )
                 compiled_unet, unet_mlir = self.compile_unet_variants(
-                    model, use_large=use_large, base_model=self.base_model_id
+                    model, use_large=use_large
                 )
             else:
                 for model_id in unet_inputs:
@@ -1066,7 +830,7 @@ class SharkifyStableDiffusionModel:
                     try:
                         compiled_unet, unet_mlir = self.compile_unet_variants(
-                            model, use_large=use_large, base_model=model_id
+                            model, use_large=use_large
                         )
                     except Exception as e:
                         print(e)
@@ -1105,10 +869,7 @@ class SharkifyStableDiffusionModel:
             is_base_vae = self.base_vae
             if self.is_upscaler:
                 self.base_vae = True
-            if self.is_sdxl:
-                compiled_vae, vae_mlir = self.get_vae_sdxl()
-            else:
-                compiled_vae, vae_mlir = self.get_vae()
+            compiled_vae, vae_mlir = self.get_vae()
             self.base_vae = is_base_vae
             check_compilation(compiled_vae, "Vae")

View File

@@ -123,8 +123,8 @@ def get_clip():
     return get_shark_model(bucket, model_name, iree_flags)
 
 
-def get_tokenizer(subfolder="tokenizer"):
+def get_tokenizer():
     tokenizer = CLIPTokenizer.from_pretrained(
-        args.hf_model_id, subfolder=subfolder
+        args.hf_model_id, subfolder="tokenizer"
     )
     return tokenizer
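
The subfolder parameter dropped here existed only so the SDXL path could fetch the second tokenizer; with stock repos that is simply (a sketch):

    from transformers import CLIPTokenizer

    tokenizer_2 = CLIPTokenizer.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer_2"
    )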

View File

@@ -1,9 +1,6 @@
 from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
     Text2ImagePipeline,
 )
-from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
-    Text2ImageSDXLPipeline,
-)
 from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
     Image2ImagePipeline,
 )

View File

@@ -1,212 +0,0 @@
-import torch
-import numpy as np
-from random import randint
-from typing import Union
-from diffusers import (
-    DDIMScheduler,
-    PNDMScheduler,
-    LMSDiscreteScheduler,
-    KDPM2DiscreteScheduler,
-    EulerDiscreteScheduler,
-    EulerAncestralDiscreteScheduler,
-    DPMSolverMultistepScheduler,
-    DEISMultistepScheduler,
-    DDPMScheduler,
-    DPMSolverSinglestepScheduler,
-    KDPM2AncestralDiscreteScheduler,
-    HeunDiscreteScheduler,
-)
-from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
-from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
-    StableDiffusionPipeline,
-)
-from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-
-class Text2ImageSDXLPipeline(StableDiffusionPipeline):
-    def __init__(
-        self,
-        scheduler: Union[
-            DDIMScheduler,
-            PNDMScheduler,
-            LMSDiscreteScheduler,
-            KDPM2DiscreteScheduler,
-            EulerDiscreteScheduler,
-            EulerAncestralDiscreteScheduler,
-            DPMSolverMultistepScheduler,
-            SharkEulerDiscreteScheduler,
-            DEISMultistepScheduler,
-            DDPMScheduler,
-            DPMSolverSinglestepScheduler,
-            KDPM2AncestralDiscreteScheduler,
-            HeunDiscreteScheduler,
-        ],
-        sd_model: SharkifyStableDiffusionModel,
-        import_mlir: bool,
-        use_lora: str,
-        ondemand: bool,
-    ):
-        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
-
-    def prepare_latents(
-        self,
-        batch_size,
-        height,
-        width,
-        generator,
-        num_inference_steps,
-        dtype,
-    ):
-        latents = torch.randn(
-            (
-                batch_size,
-                4,
-                height // 8,
-                width // 8,
-            ),
-            generator=generator,
-            dtype=torch.float32,
-        ).to(dtype)
-
-        self.scheduler.set_timesteps(num_inference_steps)
-        self.scheduler.is_scale_input_called = True
-        latents = latents * self.scheduler.init_noise_sigma
-        return latents
-
-    def _get_add_time_ids(
-        self, original_size, crops_coords_top_left, target_size, dtype
-    ):
-        add_time_ids = list(
-            original_size + crops_coords_top_left + target_size
-        )
-
-        # self.unet.config.addition_time_embed_dim IS 256.
-        # self.text_encoder_2.config.projection_dim IS 1280.
-        passed_add_embed_dim = 256 * len(add_time_ids) + 1280
-        expected_add_embed_dim = 2816
-        # self.unet.add_embedding.linear_1.in_features IS 2816.
-
-        if expected_add_embed_dim != passed_add_embed_dim:
-            raise ValueError(
-                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
-            )
-
-        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
-        return add_time_ids
-
-    def generate_images(
-        self,
-        prompts,
-        neg_prompts,
-        batch_size,
-        height,
-        width,
-        num_inference_steps,
-        guidance_scale,
-        seed,
-        max_length,
-        dtype,
-        use_base_vae,
-        cpu_scheduling,
-        max_embeddings_multiples,
-    ):
-        # prompts and negative prompts must be a list.
-        if isinstance(prompts, str):
-            prompts = [prompts]
-        if isinstance(neg_prompts, str):
-            neg_prompts = [neg_prompts]
-
-        prompts = prompts * batch_size
-        neg_prompts = neg_prompts * batch_size
-
-        # seed generator to create the inital latent noise. Also handle out of range seeds.
-        # TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
-        uint32_info = np.iinfo(np.uint32)
-        uint32_min, uint32_max = uint32_info.min, uint32_info.max
-        if seed < uint32_min or seed >= uint32_max:
-            seed = randint(uint32_min, uint32_max)
-        generator = torch.manual_seed(seed)
-
-        # Get initial latents.
-        init_latents = self.prepare_latents(
-            batch_size=batch_size,
-            height=height,
-            width=width,
-            generator=generator,
-            num_inference_steps=num_inference_steps,
-            dtype=dtype,
-        )
-
-        # Get text embeddings.
-        (
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = self.encode_prompt_sdxl(
-            prompt=prompts,
-            num_images_per_prompt=1,
-            do_classifier_free_guidance=True,
-            negative_prompt=neg_prompts,
-        )
-
-        # Prepare timesteps.
-        self.scheduler.set_timesteps(num_inference_steps)
-        timesteps = self.scheduler.timesteps
-
-        # Prepare added time ids & embeddings.
-        original_size = (height, width)
-        target_size = (height, width)
-        crops_coords_top_left = (0, 0)
-        add_text_embeds = pooled_prompt_embeds
-        add_time_ids = self._get_add_time_ids(
-            original_size,
-            crops_coords_top_left,
-            target_size,
-            dtype=prompt_embeds.dtype,
-        )
-
-        prompt_embeds = torch.cat(
-            [negative_prompt_embeds, prompt_embeds], dim=0
-        )
-        add_text_embeds = torch.cat(
-            [negative_pooled_prompt_embeds, add_text_embeds], dim=0
-        )
-        add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
-
-        prompt_embeds = prompt_embeds
-        add_text_embeds = add_text_embeds.to(dtype)
-        add_time_ids = add_time_ids.repeat(batch_size * 1, 1)
-
-        # guidance scale as a float32 tensor.
-        guidance_scale = torch.tensor(guidance_scale).to(dtype)
-        prompt_embeds = prompt_embeds.to(dtype)
-        add_time_ids = add_time_ids.to(dtype)
-
-        # Get Image latents.
-        latents = self.produce_img_latents_sdxl(
-            init_latents,
-            timesteps,
-            add_text_embeds,
-            add_time_ids,
-            prompt_embeds,
-            cpu_scheduling,
-            guidance_scale,
-            dtype,
-        )
-
-        # Img latents -> PIL images.
-        all_imgs = []
-        self.load_vae()
-        for i in range(0, latents.shape[0], batch_size):
-            imgs = self.decode_latents_sdxl(latents[i : i + batch_size])
-            all_imgs.extend(imgs)
-        if self.ondemand:
-            self.unload_vae()
-
-        return all_imgs
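
A quick check of the arithmetic in the deleted _get_add_time_ids: SDXL micro-conditioning packs six scalars (original H/W, crop top/left, target H/W), each embedded at addition_time_embed_dim 256 and concatenated with the 1280-dim pooled text embedding:

    add_time_ids = [1024, 1024, 0, 0, 1024, 1024]  # orig_h, orig_w, crop_top, crop_left, tgt_h, tgt_w
    passed_add_embed_dim = 256 * len(add_time_ids) + 1280  # 256 * 6 + 1280 = 2816
    assert passed_add_embed_dim == 2816  # unet.add_embedding.linear_1.in_features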

View File

@@ -33,7 +33,6 @@ from apps.stable_diffusion.src.utils import (
     end_profiling,
 )
 import sys
-from typing import List, Optional
 
 SD_STATE_IDLE = "idle"
 SD_STATE_CANCEL = "cancel"
@@ -64,7 +63,6 @@ class StableDiffusionPipeline:
     ):
         self.vae = None
         self.text_encoder = None
-        self.text_encoder_2 = None
         self.unet = None
         self.unet_512 = None
         self.model_max_length = 77
@@ -108,34 +106,6 @@ class StableDiffusionPipeline:
         del self.text_encoder
         self.text_encoder = None
 
-    def load_clip_sdxl(self):
-        if self.text_encoder and self.text_encoder_2:
-            return
-
-        if self.import_mlir or self.use_lora:
-            if not self.import_mlir:
-                print(
-                    "Warning: LoRA provided but import_mlir not specified. "
-                    "Importing MLIR anyways."
-                )
-            self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip()
-        else:
-            try:
-                # TODO: Fix this for SDXL
-                self.text_encoder = get_clip()
-            except Exception as e:
-                print(e)
-                print("download pipeline failed, falling back to import_mlir")
-                (
-                    self.text_encoder,
-                    self.text_encoder_2,
-                ) = self.sd_model.sdxl_clip()
-
-    def unload_clip_sdxl(self):
-        del self.text_encoder, self.text_encoder_2
-        self.text_encoder = None
-        self.text_encoder_2 = None
-
     def load_unet(self):
         if self.unet is not None:
             return
@@ -190,177 +160,6 @@ class StableDiffusionPipeline:
         del self.vae
         self.vae = None
 
-    def encode_prompt_sdxl(
-        self,
-        prompt: str,
-        num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
-        negative_prompt: Optional[str] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-    ):
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        # Define tokenizers and text encoders
-        self.tokenizer_2 = get_tokenizer("tokenizer_2")
-        self.load_clip_sdxl()
-        tokenizers = (
-            [self.tokenizer, self.tokenizer_2]
-            if self.tokenizer is not None
-            else [self.tokenizer_2]
-        )
-        text_encoders = (
-            [self.text_encoder, self.text_encoder_2]
-            if self.text_encoder is not None
-            else [self.text_encoder_2]
-        )
-
-        # textual inversion: procecss multi-vector tokens if necessary
-        prompt_embeds_list = []
-        prompts = [prompt, prompt]
-        for prompt, tokenizer, text_encoder in zip(
-            prompts, tokenizers, text_encoders
-        ):
-            text_inputs = tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            untruncated_ids = tokenizer(
-                prompt, padding="longest", return_tensors="pt"
-            ).input_ids
-
-            if untruncated_ids.shape[-1] >= text_input_ids.shape[
-                -1
-            ] and not torch.equal(text_input_ids, untruncated_ids):
-                removed_text = tokenizer.batch_decode(
-                    untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
-                )
-                print(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {tokenizer.model_max_length} tokens: {removed_text}"
-                )
-
-            text_encoder_output = text_encoder("forward", (text_input_ids,))
-            prompt_embeds = torch.from_numpy(text_encoder_output[0])
-            pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1])
-            prompt_embeds_list.append(prompt_embeds)
-
-        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
-
-        # get unconditional embeddings for classifier free guidance
-        zero_out_negative_prompt = (
-            negative_prompt is None
-            and self.config.force_zeros_for_empty_prompt
-        )
-        if (
-            do_classifier_free_guidance
-            and negative_prompt_embeds is None
-            and zero_out_negative_prompt
-        ):
-            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
-            negative_pooled_prompt_embeds = torch.zeros_like(
-                pooled_prompt_embeds
-            )
-        elif do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or ""
-            negative_prompt_2 = negative_prompt
-
-            uncond_tokens: List[str]
-            if prompt is not None and type(prompt) is not type(
-                negative_prompt
-            ):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = [negative_prompt, negative_prompt_2]
-
-            negative_prompt_embeds_list = []
-            for negative_prompt, tokenizer, text_encoder in zip(
-                uncond_tokens, tokenizers, text_encoders
-            ):
-                max_length = prompt_embeds.shape[1]
-                uncond_input = tokenizer(
-                    negative_prompt,
-                    padding="max_length",
-                    max_length=max_length,
-                    truncation=True,
-                    return_tensors="pt",
-                )
-                text_encoder_output = text_encoder(
-                    "forward", (uncond_input.input_ids,)
-                )
-                negative_prompt_embeds = torch.from_numpy(
-                    text_encoder_output[0]
-                )
-                negative_pooled_prompt_embeds = torch.from_numpy(
-                    text_encoder_output[1]
-                )
-                negative_prompt_embeds_list.append(negative_prompt_embeds)
-
-            negative_prompt_embeds = torch.concat(
-                negative_prompt_embeds_list, dim=-1
-            )
-
-        if self.ondemand:
-            self.unload_clip_sdxl()
-
-        # TODO: Look into dtype for text_encoder_2!
-        prompt_embeds = prompt_embeds.to(dtype=torch.float32)
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(
-            bs_embed * num_images_per_prompt, seq_len, -1
-        )
-
-        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-        seq_len = negative_prompt_embeds.shape[1]
-        negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32)
-        negative_prompt_embeds = negative_prompt_embeds.repeat(
-            1, num_images_per_prompt, 1
-        )
-        negative_prompt_embeds = negative_prompt_embeds.view(
-            batch_size * num_images_per_prompt, seq_len, -1
-        )
-
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(
-            1, num_images_per_prompt
-        ).view(bs_embed * num_images_per_prompt, -1)
-        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
-            1, num_images_per_prompt
-        ).view(bs_embed * num_images_per_prompt, -1)
-
-        return (
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        )
-
     def encode_prompts(self, prompts, neg_prompts, max_length):
         # Tokenize text and get embeddings
         text_input = self.tokenizer(
@@ -507,69 +306,6 @@ class StableDiffusionPipeline:
         all_latents = torch.cat(latent_history, dim=0)
         return all_latents
 
-    def produce_img_latents_sdxl(
-        self,
-        latents,
-        total_timesteps,
-        add_text_embeds,
-        add_time_ids,
-        prompt_embeds,
-        cpu_scheduling,
-        guidance_scale,
-        dtype,
-    ):
-        self.status = SD_STATE_IDLE
-        step_time_sum = 0
-        extra_step_kwargs = {"generator": None}
-        self.load_unet()
-        for i, t in tqdm(enumerate(total_timesteps)):
-            step_start_time = time.time()
-            timestep = torch.tensor([t]).to(dtype).detach().numpy()
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2)
-            latent_model_input = self.scheduler.scale_model_input(
-                latent_model_input, t
-            ).to(dtype)
-
-            noise_pred = self.unet(
-                "forward",
-                (
-                    latent_model_input,
-                    timestep,
-                    prompt_embeds,
-                    add_text_embeds,
-                    add_time_ids,
-                    guidance_scale,
-                ),
-                send_to_host=False,
-            )
-            latents = self.scheduler.step(
-                noise_pred, t, latents, **extra_step_kwargs, return_dict=False
-            )[0]
-
-            step_time = (time.time() - step_start_time) * 1000
-            step_time_sum += step_time
-
-            if self.status == SD_STATE_CANCEL:
-                break
-
-        if self.ondemand:
-            self.unload_unet()
-
-        avg_step_time = step_time_sum / len(total_timesteps)
-        self.log += f"\nAverage step time: {avg_step_time}ms/it"
-        return latents
-
-    def decode_latents_sdxl(self, latents):
-        latents = latents.to(torch.float32)
-        images = self.vae("forward", (latents,))
-        images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
-        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
-        images = (images * 255).round().astype("uint8")
-        pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
-        return pil_images
-
     @classmethod
     def from_pretrained(
         cls,
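
Structurally, the deleted produce_img_latents_sdxl is the standard diffusers denoising loop with the UNet call routed through the compiled module. Its skeleton, with unet_fn standing in for the compiled forward (a sketch, not the removed code verbatim):

    import torch

    def denoise(latents, scheduler, unet_fn, num_steps, guidance_scale):
        scheduler.set_timesteps(num_steps)
        for t in scheduler.timesteps:
            latent_in = torch.cat([latents] * 2)  # doubled batch for CFG
            latent_in = scheduler.scale_model_input(latent_in, t)
            noise_pred = unet_fn(latent_in, t, guidance_scale)  # guidance mixed inside
            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        return latents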
@@ -619,7 +355,6 @@ class StableDiffusionPipeline:
"OutpaintPipeline", "OutpaintPipeline",
] ]
is_upscaler = cls.__name__ in ["UpscalerPipeline"] is_upscaler = cls.__name__ in ["UpscalerPipeline"]
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]
sd_model = SharkifyStableDiffusionModel( sd_model = SharkifyStableDiffusionModel(
model_id, model_id,
@@ -636,7 +371,6 @@ class StableDiffusionPipeline:
             debug=debug,
             is_inpaint=is_inpaint,
             is_upscaler=is_upscaler,
-            is_sdxl=is_sdxl,
             use_stencil=use_stencil,
             use_lora=use_lora,
             use_quantize=use_quantize,

View File

@@ -8,15 +8,6 @@
"dtype":"i64" "dtype":"i64"
} }
}, },
"sdxl_clip": {
"token" : {
"shape" : [
"1*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": { "vae_encode": {
"image" : { "image" : {
"shape" : [ "shape" : [
@@ -188,49 +179,6 @@
"shape": [2], "shape": [2],
"dtype": "i64" "dtype": "i64"
} }
},
"stabilityai/stable-diffusion-xl-base-1.0": {
"latents": {
"shape": [
"2*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"prompt_embeds": {
"shape": [
"2*batch_size",
"max_len",
2048
],
"dtype": "f32"
},
"text_embeds": {
"shape": [
"2*batch_size",
1280
],
"dtype": "f32"
},
"time_ids": {
"shape": [
"2*batch_size",
6
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 1,
"dtype": "f32"
}
} }
}, },
"stencil_adaptor": { "stencil_adaptor": {

View File

@@ -83,7 +83,7 @@ p.add_argument(
"--height", "--height",
type=int, type=int,
default=512, default=512,
choices=range(128, 1025, 8), choices=range(128, 769, 8),
help="The height of the output image.", help="The height of the output image.",
) )
@@ -91,7 +91,7 @@ p.add_argument(
"--width", "--width",
type=int, type=int,
default=512, default=512,
choices=range(128, 1025, 8), choices=range(128, 769, 8),
help="The width of the output image.", help="The width of the output image.",
) )
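
Since range() excludes its stop value, choices=range(128, 769, 8) restores the pre-SDXL bound of 128-768 in 8-pixel steps (the reverted range(128, 1025, 8) had allowed up to 1024):

    assert max(range(128, 769, 8)) == 768
    assert 1024 in range(128, 1025, 8) and 1024 not in range(128, 769, 8)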

View File

@@ -22,7 +22,6 @@ from apps.stable_diffusion.web.utils.common_label_calc import status_label
 from apps.stable_diffusion.src import (
     args,
     Text2ImagePipeline,
-    Text2ImageSDXLPipeline,
     get_schedulers,
     set_init_device_flags,
     utils,
@@ -160,37 +159,8 @@ def txt2img_inf(
     )
     global_obj.set_schedulers(get_schedulers(model_id))
     scheduler_obj = global_obj.get_scheduler(scheduler)
-    if height == 1024:
-        assert (
-            width == 1024
-        ), "currently we support only 1024x1024 image size via SDXL"
-        assert precision == "fp16", "currently we support fp16 for SDXL"
-        # For SDXL we set max_length as 77.
-        max_length = 77
-        txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
-            scheduler=scheduler_obj,
-            import_mlir=args.import_mlir,
-            model_id=args.hf_model_id,
-            ckpt_loc=args.ckpt_loc,
-            precision=precision,
-            max_length=max_length,
-            batch_size=batch_size,
-            height=height,
-            width=width,
-            use_base_vae=args.use_base_vae,
-            use_tuned=args.use_tuned,
-            custom_vae=args.custom_vae,
-            low_cpu_mem_usage=args.low_cpu_mem_usage,
-            debug=args.import_debug if args.import_mlir else False,
-            use_lora=args.use_lora,
-            use_quantize=args.use_quantize,
-            ondemand=args.ondemand,
-        )
-    else:
-        assert (
-            height <= 768 and width <= 768
-        ), "height/width not in supported range"
-        txt2img_obj = Text2ImagePipeline.from_pretrained(
+    global_obj.set_sd_obj(
+        Text2ImagePipeline.from_pretrained(
             scheduler=scheduler_obj,
             import_mlir=args.import_mlir,
             model_id=args.hf_model_id,
@@ -198,18 +168,17 @@ def txt2img_inf(
             precision=args.precision,
             max_length=args.max_length,
             batch_size=args.batch_size,
-            height=height,
-            width=width,
+            height=args.height,
+            width=args.width,
             use_base_vae=args.use_base_vae,
             use_tuned=args.use_tuned,
             custom_vae=args.custom_vae,
             low_cpu_mem_usage=args.low_cpu_mem_usage,
             debug=args.import_debug if args.import_mlir else False,
             use_lora=args.use_lora,
-            use_quantize=args.use_quantize,
             ondemand=args.ondemand,
         )
-    global_obj.set_sd_obj(txt2img_obj)
+    )
     global_obj.set_sd_scheduler(scheduler)
@@ -533,15 +502,15 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
                     )
                     with gr.Row():
                         height = gr.Slider(
-                            128,
-                            1024,
+                            384,
+                            768,
                             value=args.height,
                             step=8,
                             label="Height",
                         )
                         width = gr.Slider(
-                            128,
-                            1024,
+                            384,
+                            768,
                             value=args.width,
                             step=8,
                             label="Width",