Revert "[SDXL] Add SDXL pipeline to SHARK (#1731)" (#1882)

This reverts commit 9f0a421764.
This commit is contained in:
Ean Garvey
2023-10-09 20:01:44 -05:00
committed by GitHub
parent 6e409bfb77
commit 2004d16945
10 changed files with 40 additions and 875 deletions

View File

@@ -1,9 +1,9 @@
import torch
import transformers
import time
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
@@ -16,62 +16,31 @@ def main():
if args.clear_all:
clear_all()
# TODO: prompt_embeds and text_embeds form base_model.json requires fixing
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
if args.height == 1024:
assert (
args.width == 1024
), "currently we support only 1024x1024 image size via SDXL"
assert args.precision == "fp16", "currently we support fp16 for SDXL"
# For SDXL we set max_length as 77.
args.max_length = 77
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
else:
assert (
args.height <= 768 and args.width <= 768
), "height/width not in supported range"
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):

View File

@@ -9,7 +9,6 @@ from apps.stable_diffusion.src.utils import (
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
Text2ImageSDXLPipeline,
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,

View File

@@ -1,5 +1,5 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection
from transformers import CLIPTextModel
from collections import defaultdict
from pathlib import Path
import torch
@@ -53,10 +53,6 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape.append(math.ceil(height / div_val))
elif "width" in shape[i]:
new_shape.append(math.ceil(width / div_val))
elif "+" in shape[i]:
# Currently this case only hits for SDXL. So, in case any other
# case requires this operator, change this.
new_shape.append(height + width)
else:
new_shape.append(shape[i])
return new_shape
@@ -88,7 +84,6 @@ class SharkifyStableDiffusionModel:
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
is_sdxl: bool = False,
use_stencil: str = None,
use_lora: str = "",
use_quantize: str = None,
@@ -96,14 +91,8 @@ class SharkifyStableDiffusionModel:
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.is_sdxl = is_sdxl
self.height = height
self.width = width
if is_sdxl:
# We need to scale down the height/width by vae_scale_factor, which
# happens to be 8 in this case.
self.height = height // 8
self.width = width // 8
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.use_quantize = use_quantize
@@ -185,7 +174,6 @@ class SharkifyStableDiffusionModel:
model_name = {}
sub_model_list = [
"clip",
"clip2",
"unet",
"unet512",
"stencil_unet",
@@ -353,71 +341,6 @@ class SharkifyStableDiffusionModel:
)
return shark_vae, vae_mlir
def get_vae_sdxl(self):
class VaeModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
base_vae=self.base_vae,
custom_vae=self.custom_vae,
low_cpu_mem_usage=False,
):
super().__init__()
self.vae = None
if custom_vae == "":
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
def forward(self, latents):
image = self.vae.decode(latents / 0.13025, return_dict=False)[
0
]
return image
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
# Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL
# pipeline.
is_f16 = False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
vae_name_split = self.model_name["vae"].split("_")
vae_name_split[5] = "fp32"
extended_model_name = "_".join(vae_name_split)
shark_vae, vae_mlir = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=extended_model_name,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae, vae_mlir
def get_controlled_unet(self, use_large=False):
class ControlledUnetModel(torch.nn.Module):
def __init__(
@@ -764,85 +687,6 @@ class SharkifyStableDiffusionModel:
)
return shark_unet, unet_mlir
def get_unet_sdxl(self):
class UnetModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
):
if args.attention_slicing.isdigit():
self.unet.set_attention_slice(
int(args.attention_slicing)
)
else:
self.unet.set_attention_slice(args.attention_slicing)
def forward(
self,
latent,
timestep,
prompt_embeds,
text_embeds,
time_ids,
guidance_scale,
):
added_cond_kwargs = {
"text_embeds": text_embeds,
"time_ids": time_ids,
}
noise_pred = self.unet.forward(
latent,
timestep,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=None,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
input_mask = [True, True, True, True, True, True]
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(
@@ -890,68 +734,6 @@ class SharkifyStableDiffusionModel:
)
return shark_clip, clip_mlir
def get_clip_sdxl(self, clip_index=1):
class CLIPText(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
clip_index=1,
):
super().__init__()
if clip_index == 1:
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.text_encoder = (
CLIPTextModelWithProjection.from_pretrained(
model_id,
subfolder="text_encoder_2",
low_cpu_mem_usage=low_cpu_mem_usage,
)
)
def forward(self, input):
prompt_embeds = self.text_encoder(
input,
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
return prompt_embeds, pooled_prompt_embeds
clip_model = CLIPText(
low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index
)
if clip_index == 1:
model_name = self.model_name["clip"]
else:
model_name = self.model_name["clip2"]
save_dir = os.path.join(self.sharktank_dir, model_name)
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip, clip_mlir = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
extended_model_name=model_name,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
model_name="clip",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_clip, clip_mlir
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
if not custom_vae.endswith((".ckpt", ".safetensors")):
@@ -984,9 +766,7 @@ class SharkifyStableDiffusionModel:
}
return vae_dict
def compile_unet_variants(self, model, use_large=False, base_model=""):
if self.is_sdxl:
return self.get_unet_sdxl()
def compile_unet_variants(self, model, use_large=False):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler(use_large=use_large)
@@ -1028,22 +808,6 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def sdxl_clip(self):
try:
self.inputs["clip"] = self.get_input_info_for(
base_models["sdxl_clip"]
)
compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1)
compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2)
check_compilation(compiled_clip, "Clip")
check_compilation(compiled_clip, "Clip2")
if self.return_mlir:
return clip_mlir, clip_mlir2
return compiled_clip, compiled_clip2
except Exception as e:
sys.exit(e)
def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
@@ -1055,7 +819,7 @@ class SharkifyStableDiffusionModel:
unet_inputs[self.base_model_id]
)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large, base_model=self.base_model_id
model, use_large=use_large
)
else:
for model_id in unet_inputs:
@@ -1066,7 +830,7 @@ class SharkifyStableDiffusionModel:
try:
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large, base_model=model_id
model, use_large=use_large
)
except Exception as e:
print(e)
@@ -1105,10 +869,7 @@ class SharkifyStableDiffusionModel:
is_base_vae = self.base_vae
if self.is_upscaler:
self.base_vae = True
if self.is_sdxl:
compiled_vae, vae_mlir = self.get_vae_sdxl()
else:
compiled_vae, vae_mlir = self.get_vae()
compiled_vae, vae_mlir = self.get_vae()
self.base_vae = is_base_vae
check_compilation(compiled_vae, "Vae")

View File

@@ -123,8 +123,8 @@ def get_clip():
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer(subfolder="tokenizer"):
def get_tokenizer():
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder=subfolder
args.hf_model_id, subfolder="tokenizer"
)
return tokenizer

View File

@@ -1,9 +1,6 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
Text2ImageSDXLPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
Image2ImagePipeline,
)

View File

@@ -1,212 +0,0 @@
import torch
import numpy as np
from random import randint
from typing import Union
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Text2ImageSDXLPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def _get_add_time_ids(
self, original_size, crops_coords_top_left, target_size, dtype
):
add_time_ids = list(
original_size + crops_coords_top_left + target_size
)
# self.unet.config.addition_time_embed_dim IS 256.
# self.text_encoder_2.config.projection_dim IS 1280.
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
expected_add_embed_dim = 2816
# self.unet.add_embedding.linear_1.in_features IS 2816.
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings.
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt_sdxl(
prompt=prompts,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt=neg_prompts,
)
# Prepare timesteps.
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# Prepare added time ids & embeddings.
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
add_text_embeds = pooled_prompt_embeds
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
)
prompt_embeds = torch.cat(
[negative_prompt_embeds, prompt_embeds], dim=0
)
add_text_embeds = torch.cat(
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds
add_text_embeds = add_text_embeds.to(dtype)
add_time_ids = add_time_ids.repeat(batch_size * 1, 1)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(dtype)
prompt_embeds = prompt_embeds.to(dtype)
add_time_ids = add_time_ids.to(dtype)
# Get Image latents.
latents = self.produce_img_latents_sdxl(
init_latents,
timesteps,
add_text_embeds,
add_time_ids,
prompt_embeds,
cpu_scheduling,
guidance_scale,
dtype,
)
# Img latents -> PIL images.
all_imgs = []
self.load_vae()
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents_sdxl(latents[i : i + batch_size])
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -33,7 +33,6 @@ from apps.stable_diffusion.src.utils import (
end_profiling,
)
import sys
from typing import List, Optional
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
@@ -64,7 +63,6 @@ class StableDiffusionPipeline:
):
self.vae = None
self.text_encoder = None
self.text_encoder_2 = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
@@ -108,34 +106,6 @@ class StableDiffusionPipeline:
del self.text_encoder
self.text_encoder = None
def load_clip_sdxl(self):
if self.text_encoder and self.text_encoder_2:
return
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. "
"Importing MLIR anyways."
)
self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip()
else:
try:
# TODO: Fix this for SDXL
self.text_encoder = get_clip()
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
(
self.text_encoder,
self.text_encoder_2,
) = self.sd_model.sdxl_clip()
def unload_clip_sdxl(self):
del self.text_encoder, self.text_encoder_2
self.text_encoder = None
self.text_encoder_2 = None
def load_unet(self):
if self.unet is not None:
return
@@ -190,177 +160,6 @@ class StableDiffusionPipeline:
del self.vae
self.vae = None
def encode_prompt_sdxl(
self,
prompt: str,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
self.tokenizer_2 = get_tokenizer("tokenizer_2")
self.load_clip_sdxl()
tokenizers = (
[self.tokenizer, self.tokenizer_2]
if self.tokenizer is not None
else [self.tokenizer_2]
)
text_encoders = (
[self.text_encoder, self.text_encoder_2]
if self.text_encoder is not None
else [self.text_encoder_2]
)
# textual inversion: procecss multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt]
for prompt, tokenizer, text_encoder in zip(
prompts, tokenizers, text_encoders
):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[
-1
] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
)
print(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
text_encoder_output = text_encoder("forward", (text_input_ids,))
prompt_embeds = torch.from_numpy(text_encoder_output[0])
pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1])
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = (
negative_prompt is None
and self.config.force_zeros_for_empty_prompt
)
if (
do_classifier_free_guidance
and negative_prompt_embeds is None
and zero_out_negative_prompt
):
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(
pooled_prompt_embeds
)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(
negative_prompt
):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(
uncond_tokens, tokenizers, text_encoders
):
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_encoder_output = text_encoder(
"forward", (uncond_input.input_ids,)
)
negative_prompt_embeds = torch.from_numpy(
text_encoder_output[0]
)
negative_pooled_prompt_embeds = torch.from_numpy(
text_encoder_output[1]
)
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_prompt_embeds = torch.concat(
negative_prompt_embeds_list, dim=-1
)
if self.ondemand:
self.unload_clip_sdxl()
# TODO: Look into dtype for text_encoder_2!
prompt_embeds = prompt_embeds.to(dtype=torch.float32)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32)
negative_prompt_embeds = negative_prompt_embeds.repeat(
1, num_images_per_prompt, 1
)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(
1, num_images_per_prompt
).view(bs_embed * num_images_per_prompt, -1)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
1, num_images_per_prompt
).view(bs_embed * num_images_per_prompt, -1)
return (
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
text_input = self.tokenizer(
@@ -507,69 +306,6 @@ class StableDiffusionPipeline:
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def produce_img_latents_sdxl(
self,
latents,
total_timesteps,
add_text_embeds,
add_time_ids,
prompt_embeds,
cpu_scheduling,
guidance_scale,
dtype,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
extra_step_kwargs = {"generator": None}
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
).to(dtype)
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
prompt_embeds,
add_text_embeds,
add_time_ids,
guidance_scale,
),
send_to_host=False,
)
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
step_time = (time.time() - step_start_time) * 1000
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
return latents
def decode_latents_sdxl(self, latents):
latents = latents.to(torch.float32)
images = self.vae("forward", (latents,))
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
return pil_images
@classmethod
def from_pretrained(
cls,
@@ -619,7 +355,6 @@ class StableDiffusionPipeline:
"OutpaintPipeline",
]
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]
sd_model = SharkifyStableDiffusionModel(
model_id,
@@ -636,7 +371,6 @@ class StableDiffusionPipeline:
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
is_sdxl=is_sdxl,
use_stencil=use_stencil,
use_lora=use_lora,
use_quantize=use_quantize,

View File

@@ -8,15 +8,6 @@
"dtype":"i64"
}
},
"sdxl_clip": {
"token" : {
"shape" : [
"1*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
@@ -188,49 +179,6 @@
"shape": [2],
"dtype": "i64"
}
},
"stabilityai/stable-diffusion-xl-base-1.0": {
"latents": {
"shape": [
"2*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"prompt_embeds": {
"shape": [
"2*batch_size",
"max_len",
2048
],
"dtype": "f32"
},
"text_embeds": {
"shape": [
"2*batch_size",
1280
],
"dtype": "f32"
},
"time_ids": {
"shape": [
"2*batch_size",
6
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 1,
"dtype": "f32"
}
}
},
"stencil_adaptor": {

View File

@@ -83,7 +83,7 @@ p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 1025, 8),
choices=range(128, 769, 8),
help="The height of the output image.",
)
@@ -91,7 +91,7 @@ p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 1025, 8),
choices=range(128, 769, 8),
help="The width of the output image.",
)

View File

@@ -22,7 +22,6 @@ from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
@@ -160,37 +159,8 @@ def txt2img_inf(
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
if height == 1024:
assert (
width == 1024
), "currently we support only 1024x1024 image size via SDXL"
assert precision == "fp16", "currently we support fp16 for SDXL"
# For SDXL we set max_length as 77.
max_length = 77
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=precision,
max_length=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
else:
assert (
height <= 768 and width <= 768
), "height/width not in supported range"
txt2img_obj = Text2ImagePipeline.from_pretrained(
global_obj.set_sd_obj(
Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
@@ -198,18 +168,17 @@ def txt2img_inf(
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=height,
width=width,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
global_obj.set_sd_obj(txt2img_obj)
)
global_obj.set_sd_scheduler(scheduler)
@@ -533,15 +502,15 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
with gr.Row():
height = gr.Slider(
128,
1024,
384,
768,
value=args.height,
step=8,
label="Height",
)
width = gr.Slider(
128,
1024,
384,
768,
value=args.width,
step=8,
label="Width",