mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
This reverts commit 9f0a421764.
This commit is contained in:
@@ -1,9 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import transformers
|
||||||
import time
|
import time
|
||||||
from apps.stable_diffusion.src import (
|
from apps.stable_diffusion.src import (
|
||||||
args,
|
args,
|
||||||
Text2ImagePipeline,
|
Text2ImagePipeline,
|
||||||
Text2ImageSDXLPipeline,
|
|
||||||
get_schedulers,
|
get_schedulers,
|
||||||
set_init_device_flags,
|
set_init_device_flags,
|
||||||
utils,
|
utils,
|
||||||
@@ -16,62 +16,31 @@ def main():
|
|||||||
if args.clear_all:
|
if args.clear_all:
|
||||||
clear_all()
|
clear_all()
|
||||||
|
|
||||||
# TODO: prompt_embeds and text_embeds form base_model.json requires fixing
|
|
||||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||||
set_init_device_flags()
|
set_init_device_flags()
|
||||||
schedulers = get_schedulers(args.hf_model_id)
|
schedulers = get_schedulers(args.hf_model_id)
|
||||||
scheduler_obj = schedulers[args.scheduler]
|
scheduler_obj = schedulers[args.scheduler]
|
||||||
seed = args.seed
|
seed = args.seed
|
||||||
if args.height == 1024:
|
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
||||||
assert (
|
scheduler=scheduler_obj,
|
||||||
args.width == 1024
|
import_mlir=args.import_mlir,
|
||||||
), "currently we support only 1024x1024 image size via SDXL"
|
model_id=args.hf_model_id,
|
||||||
assert args.precision == "fp16", "currently we support fp16 for SDXL"
|
ckpt_loc=args.ckpt_loc,
|
||||||
# For SDXL we set max_length as 77.
|
precision=args.precision,
|
||||||
args.max_length = 77
|
max_length=args.max_length,
|
||||||
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
|
batch_size=args.batch_size,
|
||||||
scheduler=scheduler_obj,
|
height=args.height,
|
||||||
import_mlir=args.import_mlir,
|
width=args.width,
|
||||||
model_id=args.hf_model_id,
|
use_base_vae=args.use_base_vae,
|
||||||
ckpt_loc=args.ckpt_loc,
|
use_tuned=args.use_tuned,
|
||||||
precision=args.precision,
|
custom_vae=args.custom_vae,
|
||||||
max_length=args.max_length,
|
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||||
batch_size=args.batch_size,
|
debug=args.import_debug if args.import_mlir else False,
|
||||||
height=args.height,
|
use_lora=args.use_lora,
|
||||||
width=args.width,
|
use_quantize=args.use_quantize,
|
||||||
use_base_vae=args.use_base_vae,
|
ondemand=args.ondemand,
|
||||||
use_tuned=args.use_tuned,
|
)
|
||||||
custom_vae=args.custom_vae,
|
|
||||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
|
||||||
debug=args.import_debug if args.import_mlir else False,
|
|
||||||
use_lora=args.use_lora,
|
|
||||||
use_quantize=args.use_quantize,
|
|
||||||
ondemand=args.ondemand,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
args.height <= 768 and args.width <= 768
|
|
||||||
), "height/width not in supported range"
|
|
||||||
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
|
||||||
scheduler=scheduler_obj,
|
|
||||||
import_mlir=args.import_mlir,
|
|
||||||
model_id=args.hf_model_id,
|
|
||||||
ckpt_loc=args.ckpt_loc,
|
|
||||||
precision=args.precision,
|
|
||||||
max_length=args.max_length,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
height=args.height,
|
|
||||||
width=args.width,
|
|
||||||
use_base_vae=args.use_base_vae,
|
|
||||||
use_tuned=args.use_tuned,
|
|
||||||
custom_vae=args.custom_vae,
|
|
||||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
|
||||||
debug=args.import_debug if args.import_mlir else False,
|
|
||||||
use_lora=args.use_lora,
|
|
||||||
use_quantize=args.use_quantize,
|
|
||||||
ondemand=args.ondemand,
|
|
||||||
)
|
|
||||||
|
|
||||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||||
for current_batch in range(args.batch_count):
|
for current_batch in range(args.batch_count):
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from apps.stable_diffusion.src.utils import (
|
|||||||
)
|
)
|
||||||
from apps.stable_diffusion.src.pipelines import (
|
from apps.stable_diffusion.src.pipelines import (
|
||||||
Text2ImagePipeline,
|
Text2ImagePipeline,
|
||||||
Text2ImageSDXLPipeline,
|
|
||||||
Image2ImagePipeline,
|
Image2ImagePipeline,
|
||||||
InpaintPipeline,
|
InpaintPipeline,
|
||||||
OutpaintPipeline,
|
OutpaintPipeline,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
|
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
|
||||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
from transformers import CLIPTextModel
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
@@ -53,10 +53,6 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
|
|||||||
new_shape.append(math.ceil(height / div_val))
|
new_shape.append(math.ceil(height / div_val))
|
||||||
elif "width" in shape[i]:
|
elif "width" in shape[i]:
|
||||||
new_shape.append(math.ceil(width / div_val))
|
new_shape.append(math.ceil(width / div_val))
|
||||||
elif "+" in shape[i]:
|
|
||||||
# Currently this case only hits for SDXL. So, in case any other
|
|
||||||
# case requires this operator, change this.
|
|
||||||
new_shape.append(height + width)
|
|
||||||
else:
|
else:
|
||||||
new_shape.append(shape[i])
|
new_shape.append(shape[i])
|
||||||
return new_shape
|
return new_shape
|
||||||
@@ -88,7 +84,6 @@ class SharkifyStableDiffusionModel:
|
|||||||
generate_vmfb: bool = True,
|
generate_vmfb: bool = True,
|
||||||
is_inpaint: bool = False,
|
is_inpaint: bool = False,
|
||||||
is_upscaler: bool = False,
|
is_upscaler: bool = False,
|
||||||
is_sdxl: bool = False,
|
|
||||||
use_stencil: str = None,
|
use_stencil: str = None,
|
||||||
use_lora: str = "",
|
use_lora: str = "",
|
||||||
use_quantize: str = None,
|
use_quantize: str = None,
|
||||||
@@ -96,14 +91,8 @@ class SharkifyStableDiffusionModel:
|
|||||||
):
|
):
|
||||||
self.check_params(max_len, width, height)
|
self.check_params(max_len, width, height)
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
self.is_sdxl = is_sdxl
|
self.height = height // 8
|
||||||
self.height = height
|
self.width = width // 8
|
||||||
self.width = width
|
|
||||||
if is_sdxl:
|
|
||||||
# We need to scale down the height/width by vae_scale_factor, which
|
|
||||||
# happens to be 8 in this case.
|
|
||||||
self.height = height // 8
|
|
||||||
self.width = width // 8
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.custom_weights = custom_weights
|
self.custom_weights = custom_weights
|
||||||
self.use_quantize = use_quantize
|
self.use_quantize = use_quantize
|
||||||
@@ -185,7 +174,6 @@ class SharkifyStableDiffusionModel:
|
|||||||
model_name = {}
|
model_name = {}
|
||||||
sub_model_list = [
|
sub_model_list = [
|
||||||
"clip",
|
"clip",
|
||||||
"clip2",
|
|
||||||
"unet",
|
"unet",
|
||||||
"unet512",
|
"unet512",
|
||||||
"stencil_unet",
|
"stencil_unet",
|
||||||
@@ -353,71 +341,6 @@ class SharkifyStableDiffusionModel:
|
|||||||
)
|
)
|
||||||
return shark_vae, vae_mlir
|
return shark_vae, vae_mlir
|
||||||
|
|
||||||
def get_vae_sdxl(self):
|
|
||||||
class VaeModel(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id=self.model_id,
|
|
||||||
base_vae=self.base_vae,
|
|
||||||
custom_vae=self.custom_vae,
|
|
||||||
low_cpu_mem_usage=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.vae = None
|
|
||||||
if custom_vae == "":
|
|
||||||
self.vae = AutoencoderKL.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
subfolder="vae",
|
|
||||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
|
||||||
)
|
|
||||||
elif not isinstance(custom_vae, dict):
|
|
||||||
self.vae = AutoencoderKL.from_pretrained(
|
|
||||||
custom_vae,
|
|
||||||
subfolder="vae",
|
|
||||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.vae = AutoencoderKL.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
subfolder="vae",
|
|
||||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
|
||||||
)
|
|
||||||
self.vae.load_state_dict(custom_vae)
|
|
||||||
|
|
||||||
def forward(self, latents):
|
|
||||||
image = self.vae.decode(latents / 0.13025, return_dict=False)[
|
|
||||||
0
|
|
||||||
]
|
|
||||||
return image
|
|
||||||
|
|
||||||
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
|
||||||
inputs = tuple(self.inputs["vae"])
|
|
||||||
# Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL
|
|
||||||
# pipeline.
|
|
||||||
is_f16 = False
|
|
||||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
|
|
||||||
if self.debug:
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
vae_name_split = self.model_name["vae"].split("_")
|
|
||||||
vae_name_split[5] = "fp32"
|
|
||||||
extended_model_name = "_".join(vae_name_split)
|
|
||||||
shark_vae, vae_mlir = compile_through_fx(
|
|
||||||
vae,
|
|
||||||
inputs,
|
|
||||||
is_f16=is_f16,
|
|
||||||
use_tuned=self.use_tuned,
|
|
||||||
extended_model_name=extended_model_name,
|
|
||||||
debug=self.debug,
|
|
||||||
generate_vmfb=self.generate_vmfb,
|
|
||||||
save_dir=save_dir,
|
|
||||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
|
||||||
base_model_id=self.base_model_id,
|
|
||||||
model_name="vae",
|
|
||||||
precision=self.precision,
|
|
||||||
return_mlir=self.return_mlir,
|
|
||||||
)
|
|
||||||
return shark_vae, vae_mlir
|
|
||||||
|
|
||||||
def get_controlled_unet(self, use_large=False):
|
def get_controlled_unet(self, use_large=False):
|
||||||
class ControlledUnetModel(torch.nn.Module):
|
class ControlledUnetModel(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -764,85 +687,6 @@ class SharkifyStableDiffusionModel:
|
|||||||
)
|
)
|
||||||
return shark_unet, unet_mlir
|
return shark_unet, unet_mlir
|
||||||
|
|
||||||
def get_unet_sdxl(self):
|
|
||||||
class UnetModel(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id=self.model_id,
|
|
||||||
low_cpu_mem_usage=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.unet = UNet2DConditionModel.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
subfolder="unet",
|
|
||||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
args.attention_slicing is not None
|
|
||||||
and args.attention_slicing != "none"
|
|
||||||
):
|
|
||||||
if args.attention_slicing.isdigit():
|
|
||||||
self.unet.set_attention_slice(
|
|
||||||
int(args.attention_slicing)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.unet.set_attention_slice(args.attention_slicing)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
latent,
|
|
||||||
timestep,
|
|
||||||
prompt_embeds,
|
|
||||||
text_embeds,
|
|
||||||
time_ids,
|
|
||||||
guidance_scale,
|
|
||||||
):
|
|
||||||
added_cond_kwargs = {
|
|
||||||
"text_embeds": text_embeds,
|
|
||||||
"time_ids": time_ids,
|
|
||||||
}
|
|
||||||
noise_pred = self.unet.forward(
|
|
||||||
latent,
|
|
||||||
timestep,
|
|
||||||
encoder_hidden_states=prompt_embeds,
|
|
||||||
cross_attention_kwargs=None,
|
|
||||||
added_cond_kwargs=added_cond_kwargs,
|
|
||||||
return_dict=False,
|
|
||||||
)[0]
|
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
|
||||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
|
||||||
noise_pred_text - noise_pred_uncond
|
|
||||||
)
|
|
||||||
return noise_pred
|
|
||||||
|
|
||||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
|
||||||
is_f16 = True if self.precision == "fp16" else False
|
|
||||||
inputs = tuple(self.inputs["unet"])
|
|
||||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
|
|
||||||
input_mask = [True, True, True, True, True, True]
|
|
||||||
if self.debug:
|
|
||||||
os.makedirs(
|
|
||||||
save_dir,
|
|
||||||
exist_ok=True,
|
|
||||||
)
|
|
||||||
shark_unet, unet_mlir = compile_through_fx(
|
|
||||||
unet,
|
|
||||||
inputs,
|
|
||||||
extended_model_name=self.model_name["unet"],
|
|
||||||
is_f16=is_f16,
|
|
||||||
f16_input_mask=input_mask,
|
|
||||||
use_tuned=self.use_tuned,
|
|
||||||
debug=self.debug,
|
|
||||||
generate_vmfb=self.generate_vmfb,
|
|
||||||
save_dir=save_dir,
|
|
||||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
|
||||||
base_model_id=self.base_model_id,
|
|
||||||
model_name="unet",
|
|
||||||
precision=self.precision,
|
|
||||||
return_mlir=self.return_mlir,
|
|
||||||
)
|
|
||||||
return shark_unet, unet_mlir
|
|
||||||
|
|
||||||
def get_clip(self):
|
def get_clip(self):
|
||||||
class CLIPText(torch.nn.Module):
|
class CLIPText(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -890,68 +734,6 @@ class SharkifyStableDiffusionModel:
|
|||||||
)
|
)
|
||||||
return shark_clip, clip_mlir
|
return shark_clip, clip_mlir
|
||||||
|
|
||||||
def get_clip_sdxl(self, clip_index=1):
|
|
||||||
class CLIPText(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id=self.model_id,
|
|
||||||
low_cpu_mem_usage=False,
|
|
||||||
clip_index=1,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
if clip_index == 1:
|
|
||||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
subfolder="text_encoder",
|
|
||||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.text_encoder = (
|
|
||||||
CLIPTextModelWithProjection.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
subfolder="text_encoder_2",
|
|
||||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
prompt_embeds = self.text_encoder(
|
|
||||||
input,
|
|
||||||
output_hidden_states=True,
|
|
||||||
)
|
|
||||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
|
||||||
pooled_prompt_embeds = prompt_embeds[0]
|
|
||||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
|
||||||
return prompt_embeds, pooled_prompt_embeds
|
|
||||||
|
|
||||||
clip_model = CLIPText(
|
|
||||||
low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index
|
|
||||||
)
|
|
||||||
if clip_index == 1:
|
|
||||||
model_name = self.model_name["clip"]
|
|
||||||
else:
|
|
||||||
model_name = self.model_name["clip2"]
|
|
||||||
save_dir = os.path.join(self.sharktank_dir, model_name)
|
|
||||||
if self.debug:
|
|
||||||
os.makedirs(
|
|
||||||
save_dir,
|
|
||||||
exist_ok=True,
|
|
||||||
)
|
|
||||||
shark_clip, clip_mlir = compile_through_fx(
|
|
||||||
clip_model,
|
|
||||||
tuple(self.inputs["clip"]),
|
|
||||||
extended_model_name=model_name,
|
|
||||||
debug=self.debug,
|
|
||||||
generate_vmfb=self.generate_vmfb,
|
|
||||||
save_dir=save_dir,
|
|
||||||
extra_args=get_opt_flags("clip", precision="fp32"),
|
|
||||||
base_model_id=self.base_model_id,
|
|
||||||
model_name="clip",
|
|
||||||
precision=self.precision,
|
|
||||||
return_mlir=self.return_mlir,
|
|
||||||
)
|
|
||||||
return shark_clip, clip_mlir
|
|
||||||
|
|
||||||
def process_custom_vae(self):
|
def process_custom_vae(self):
|
||||||
custom_vae = self.custom_vae.lower()
|
custom_vae = self.custom_vae.lower()
|
||||||
if not custom_vae.endswith((".ckpt", ".safetensors")):
|
if not custom_vae.endswith((".ckpt", ".safetensors")):
|
||||||
@@ -984,9 +766,7 @@ class SharkifyStableDiffusionModel:
|
|||||||
}
|
}
|
||||||
return vae_dict
|
return vae_dict
|
||||||
|
|
||||||
def compile_unet_variants(self, model, use_large=False, base_model=""):
|
def compile_unet_variants(self, model, use_large=False):
|
||||||
if self.is_sdxl:
|
|
||||||
return self.get_unet_sdxl()
|
|
||||||
if model == "unet":
|
if model == "unet":
|
||||||
if self.is_upscaler:
|
if self.is_upscaler:
|
||||||
return self.get_unet_upscaler(use_large=use_large)
|
return self.get_unet_upscaler(use_large=use_large)
|
||||||
@@ -1028,22 +808,6 @@ class SharkifyStableDiffusionModel:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
sys.exit(e)
|
sys.exit(e)
|
||||||
|
|
||||||
def sdxl_clip(self):
|
|
||||||
try:
|
|
||||||
self.inputs["clip"] = self.get_input_info_for(
|
|
||||||
base_models["sdxl_clip"]
|
|
||||||
)
|
|
||||||
compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1)
|
|
||||||
compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2)
|
|
||||||
|
|
||||||
check_compilation(compiled_clip, "Clip")
|
|
||||||
check_compilation(compiled_clip, "Clip2")
|
|
||||||
if self.return_mlir:
|
|
||||||
return clip_mlir, clip_mlir2
|
|
||||||
return compiled_clip, compiled_clip2
|
|
||||||
except Exception as e:
|
|
||||||
sys.exit(e)
|
|
||||||
|
|
||||||
def unet(self, use_large=False):
|
def unet(self, use_large=False):
|
||||||
try:
|
try:
|
||||||
model = "stencil_unet" if self.use_stencil is not None else "unet"
|
model = "stencil_unet" if self.use_stencil is not None else "unet"
|
||||||
@@ -1055,7 +819,7 @@ class SharkifyStableDiffusionModel:
|
|||||||
unet_inputs[self.base_model_id]
|
unet_inputs[self.base_model_id]
|
||||||
)
|
)
|
||||||
compiled_unet, unet_mlir = self.compile_unet_variants(
|
compiled_unet, unet_mlir = self.compile_unet_variants(
|
||||||
model, use_large=use_large, base_model=self.base_model_id
|
model, use_large=use_large
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
for model_id in unet_inputs:
|
for model_id in unet_inputs:
|
||||||
@@ -1066,7 +830,7 @@ class SharkifyStableDiffusionModel:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
compiled_unet, unet_mlir = self.compile_unet_variants(
|
compiled_unet, unet_mlir = self.compile_unet_variants(
|
||||||
model, use_large=use_large, base_model=model_id
|
model, use_large=use_large
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
@@ -1105,10 +869,7 @@ class SharkifyStableDiffusionModel:
|
|||||||
is_base_vae = self.base_vae
|
is_base_vae = self.base_vae
|
||||||
if self.is_upscaler:
|
if self.is_upscaler:
|
||||||
self.base_vae = True
|
self.base_vae = True
|
||||||
if self.is_sdxl:
|
compiled_vae, vae_mlir = self.get_vae()
|
||||||
compiled_vae, vae_mlir = self.get_vae_sdxl()
|
|
||||||
else:
|
|
||||||
compiled_vae, vae_mlir = self.get_vae()
|
|
||||||
self.base_vae = is_base_vae
|
self.base_vae = is_base_vae
|
||||||
|
|
||||||
check_compilation(compiled_vae, "Vae")
|
check_compilation(compiled_vae, "Vae")
|
||||||
|
|||||||
@@ -123,8 +123,8 @@ def get_clip():
|
|||||||
return get_shark_model(bucket, model_name, iree_flags)
|
return get_shark_model(bucket, model_name, iree_flags)
|
||||||
|
|
||||||
|
|
||||||
def get_tokenizer(subfolder="tokenizer"):
|
def get_tokenizer():
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
args.hf_model_id, subfolder=subfolder
|
args.hf_model_id, subfolder="tokenizer"
|
||||||
)
|
)
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
|
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
|
||||||
Text2ImagePipeline,
|
Text2ImagePipeline,
|
||||||
)
|
)
|
||||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
|
|
||||||
Text2ImageSDXLPipeline,
|
|
||||||
)
|
|
||||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
|
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
|
||||||
Image2ImagePipeline,
|
Image2ImagePipeline,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,212 +0,0 @@
|
|||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from random import randint
|
|
||||||
from typing import Union
|
|
||||||
from diffusers import (
|
|
||||||
DDIMScheduler,
|
|
||||||
PNDMScheduler,
|
|
||||||
LMSDiscreteScheduler,
|
|
||||||
KDPM2DiscreteScheduler,
|
|
||||||
EulerDiscreteScheduler,
|
|
||||||
EulerAncestralDiscreteScheduler,
|
|
||||||
DPMSolverMultistepScheduler,
|
|
||||||
DEISMultistepScheduler,
|
|
||||||
DDPMScheduler,
|
|
||||||
DPMSolverSinglestepScheduler,
|
|
||||||
KDPM2AncestralDiscreteScheduler,
|
|
||||||
HeunDiscreteScheduler,
|
|
||||||
)
|
|
||||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
|
||||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
|
||||||
StableDiffusionPipeline,
|
|
||||||
)
|
|
||||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Text2ImageSDXLPipeline(StableDiffusionPipeline):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
scheduler: Union[
|
|
||||||
DDIMScheduler,
|
|
||||||
PNDMScheduler,
|
|
||||||
LMSDiscreteScheduler,
|
|
||||||
KDPM2DiscreteScheduler,
|
|
||||||
EulerDiscreteScheduler,
|
|
||||||
EulerAncestralDiscreteScheduler,
|
|
||||||
DPMSolverMultistepScheduler,
|
|
||||||
SharkEulerDiscreteScheduler,
|
|
||||||
DEISMultistepScheduler,
|
|
||||||
DDPMScheduler,
|
|
||||||
DPMSolverSinglestepScheduler,
|
|
||||||
KDPM2AncestralDiscreteScheduler,
|
|
||||||
HeunDiscreteScheduler,
|
|
||||||
],
|
|
||||||
sd_model: SharkifyStableDiffusionModel,
|
|
||||||
import_mlir: bool,
|
|
||||||
use_lora: str,
|
|
||||||
ondemand: bool,
|
|
||||||
):
|
|
||||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
|
||||||
|
|
||||||
def prepare_latents(
|
|
||||||
self,
|
|
||||||
batch_size,
|
|
||||||
height,
|
|
||||||
width,
|
|
||||||
generator,
|
|
||||||
num_inference_steps,
|
|
||||||
dtype,
|
|
||||||
):
|
|
||||||
latents = torch.randn(
|
|
||||||
(
|
|
||||||
batch_size,
|
|
||||||
4,
|
|
||||||
height // 8,
|
|
||||||
width // 8,
|
|
||||||
),
|
|
||||||
generator=generator,
|
|
||||||
dtype=torch.float32,
|
|
||||||
).to(dtype)
|
|
||||||
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
|
||||||
self.scheduler.is_scale_input_called = True
|
|
||||||
latents = latents * self.scheduler.init_noise_sigma
|
|
||||||
return latents
|
|
||||||
|
|
||||||
def _get_add_time_ids(
|
|
||||||
self, original_size, crops_coords_top_left, target_size, dtype
|
|
||||||
):
|
|
||||||
add_time_ids = list(
|
|
||||||
original_size + crops_coords_top_left + target_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# self.unet.config.addition_time_embed_dim IS 256.
|
|
||||||
# self.text_encoder_2.config.projection_dim IS 1280.
|
|
||||||
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
|
|
||||||
expected_add_embed_dim = 2816
|
|
||||||
# self.unet.add_embedding.linear_1.in_features IS 2816.
|
|
||||||
|
|
||||||
if expected_add_embed_dim != passed_add_embed_dim:
|
|
||||||
raise ValueError(
|
|
||||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
|
||||||
)
|
|
||||||
|
|
||||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
|
||||||
return add_time_ids
|
|
||||||
|
|
||||||
def generate_images(
|
|
||||||
self,
|
|
||||||
prompts,
|
|
||||||
neg_prompts,
|
|
||||||
batch_size,
|
|
||||||
height,
|
|
||||||
width,
|
|
||||||
num_inference_steps,
|
|
||||||
guidance_scale,
|
|
||||||
seed,
|
|
||||||
max_length,
|
|
||||||
dtype,
|
|
||||||
use_base_vae,
|
|
||||||
cpu_scheduling,
|
|
||||||
max_embeddings_multiples,
|
|
||||||
):
|
|
||||||
# prompts and negative prompts must be a list.
|
|
||||||
if isinstance(prompts, str):
|
|
||||||
prompts = [prompts]
|
|
||||||
|
|
||||||
if isinstance(neg_prompts, str):
|
|
||||||
neg_prompts = [neg_prompts]
|
|
||||||
|
|
||||||
prompts = prompts * batch_size
|
|
||||||
neg_prompts = neg_prompts * batch_size
|
|
||||||
|
|
||||||
# seed generator to create the inital latent noise. Also handle out of range seeds.
|
|
||||||
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
|
|
||||||
uint32_info = np.iinfo(np.uint32)
|
|
||||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
|
||||||
if seed < uint32_min or seed >= uint32_max:
|
|
||||||
seed = randint(uint32_min, uint32_max)
|
|
||||||
generator = torch.manual_seed(seed)
|
|
||||||
|
|
||||||
# Get initial latents.
|
|
||||||
init_latents = self.prepare_latents(
|
|
||||||
batch_size=batch_size,
|
|
||||||
height=height,
|
|
||||||
width=width,
|
|
||||||
generator=generator,
|
|
||||||
num_inference_steps=num_inference_steps,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get text embeddings.
|
|
||||||
(
|
|
||||||
prompt_embeds,
|
|
||||||
negative_prompt_embeds,
|
|
||||||
pooled_prompt_embeds,
|
|
||||||
negative_pooled_prompt_embeds,
|
|
||||||
) = self.encode_prompt_sdxl(
|
|
||||||
prompt=prompts,
|
|
||||||
num_images_per_prompt=1,
|
|
||||||
do_classifier_free_guidance=True,
|
|
||||||
negative_prompt=neg_prompts,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare timesteps.
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
|
||||||
|
|
||||||
timesteps = self.scheduler.timesteps
|
|
||||||
|
|
||||||
# Prepare added time ids & embeddings.
|
|
||||||
original_size = (height, width)
|
|
||||||
target_size = (height, width)
|
|
||||||
crops_coords_top_left = (0, 0)
|
|
||||||
add_text_embeds = pooled_prompt_embeds
|
|
||||||
add_time_ids = self._get_add_time_ids(
|
|
||||||
original_size,
|
|
||||||
crops_coords_top_left,
|
|
||||||
target_size,
|
|
||||||
dtype=prompt_embeds.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_embeds = torch.cat(
|
|
||||||
[negative_prompt_embeds, prompt_embeds], dim=0
|
|
||||||
)
|
|
||||||
add_text_embeds = torch.cat(
|
|
||||||
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
|
|
||||||
)
|
|
||||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
|
||||||
|
|
||||||
prompt_embeds = prompt_embeds
|
|
||||||
add_text_embeds = add_text_embeds.to(dtype)
|
|
||||||
add_time_ids = add_time_ids.repeat(batch_size * 1, 1)
|
|
||||||
|
|
||||||
# guidance scale as a float32 tensor.
|
|
||||||
guidance_scale = torch.tensor(guidance_scale).to(dtype)
|
|
||||||
prompt_embeds = prompt_embeds.to(dtype)
|
|
||||||
add_time_ids = add_time_ids.to(dtype)
|
|
||||||
|
|
||||||
# Get Image latents.
|
|
||||||
latents = self.produce_img_latents_sdxl(
|
|
||||||
init_latents,
|
|
||||||
timesteps,
|
|
||||||
add_text_embeds,
|
|
||||||
add_time_ids,
|
|
||||||
prompt_embeds,
|
|
||||||
cpu_scheduling,
|
|
||||||
guidance_scale,
|
|
||||||
dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Img latents -> PIL images.
|
|
||||||
all_imgs = []
|
|
||||||
self.load_vae()
|
|
||||||
for i in range(0, latents.shape[0], batch_size):
|
|
||||||
imgs = self.decode_latents_sdxl(latents[i : i + batch_size])
|
|
||||||
all_imgs.extend(imgs)
|
|
||||||
if self.ondemand:
|
|
||||||
self.unload_vae()
|
|
||||||
|
|
||||||
return all_imgs
|
|
||||||
@@ -33,7 +33,6 @@ from apps.stable_diffusion.src.utils import (
|
|||||||
end_profiling,
|
end_profiling,
|
||||||
)
|
)
|
||||||
import sys
|
import sys
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
SD_STATE_IDLE = "idle"
|
SD_STATE_IDLE = "idle"
|
||||||
SD_STATE_CANCEL = "cancel"
|
SD_STATE_CANCEL = "cancel"
|
||||||
@@ -64,7 +63,6 @@ class StableDiffusionPipeline:
|
|||||||
):
|
):
|
||||||
self.vae = None
|
self.vae = None
|
||||||
self.text_encoder = None
|
self.text_encoder = None
|
||||||
self.text_encoder_2 = None
|
|
||||||
self.unet = None
|
self.unet = None
|
||||||
self.unet_512 = None
|
self.unet_512 = None
|
||||||
self.model_max_length = 77
|
self.model_max_length = 77
|
||||||
@@ -108,34 +106,6 @@ class StableDiffusionPipeline:
|
|||||||
del self.text_encoder
|
del self.text_encoder
|
||||||
self.text_encoder = None
|
self.text_encoder = None
|
||||||
|
|
||||||
def load_clip_sdxl(self):
|
|
||||||
if self.text_encoder and self.text_encoder_2:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.import_mlir or self.use_lora:
|
|
||||||
if not self.import_mlir:
|
|
||||||
print(
|
|
||||||
"Warning: LoRA provided but import_mlir not specified. "
|
|
||||||
"Importing MLIR anyways."
|
|
||||||
)
|
|
||||||
self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip()
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
# TODO: Fix this for SDXL
|
|
||||||
self.text_encoder = get_clip()
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
print("download pipeline failed, falling back to import_mlir")
|
|
||||||
(
|
|
||||||
self.text_encoder,
|
|
||||||
self.text_encoder_2,
|
|
||||||
) = self.sd_model.sdxl_clip()
|
|
||||||
|
|
||||||
def unload_clip_sdxl(self):
|
|
||||||
del self.text_encoder, self.text_encoder_2
|
|
||||||
self.text_encoder = None
|
|
||||||
self.text_encoder_2 = None
|
|
||||||
|
|
||||||
def load_unet(self):
|
def load_unet(self):
|
||||||
if self.unet is not None:
|
if self.unet is not None:
|
||||||
return
|
return
|
||||||
@@ -190,177 +160,6 @@ class StableDiffusionPipeline:
|
|||||||
del self.vae
|
del self.vae
|
||||||
self.vae = None
|
self.vae = None
|
||||||
|
|
||||||
def encode_prompt_sdxl(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
num_images_per_prompt: int = 1,
|
|
||||||
do_classifier_free_guidance: bool = True,
|
|
||||||
negative_prompt: Optional[str] = None,
|
|
||||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
):
|
|
||||||
if prompt is not None and isinstance(prompt, str):
|
|
||||||
batch_size = 1
|
|
||||||
elif prompt is not None and isinstance(prompt, list):
|
|
||||||
batch_size = len(prompt)
|
|
||||||
else:
|
|
||||||
batch_size = prompt_embeds.shape[0]
|
|
||||||
|
|
||||||
# Define tokenizers and text encoders
|
|
||||||
self.tokenizer_2 = get_tokenizer("tokenizer_2")
|
|
||||||
self.load_clip_sdxl()
|
|
||||||
tokenizers = (
|
|
||||||
[self.tokenizer, self.tokenizer_2]
|
|
||||||
if self.tokenizer is not None
|
|
||||||
else [self.tokenizer_2]
|
|
||||||
)
|
|
||||||
text_encoders = (
|
|
||||||
[self.text_encoder, self.text_encoder_2]
|
|
||||||
if self.text_encoder is not None
|
|
||||||
else [self.text_encoder_2]
|
|
||||||
)
|
|
||||||
|
|
||||||
# textual inversion: process multi-vector tokens if necessary
|
|
||||||
prompt_embeds_list = []
|
|
||||||
prompts = [prompt, prompt]
|
|
||||||
for prompt, tokenizer, text_encoder in zip(
|
|
||||||
prompts, tokenizers, text_encoders
|
|
||||||
):
|
|
||||||
text_inputs = tokenizer(
|
|
||||||
prompt,
|
|
||||||
padding="max_length",
|
|
||||||
max_length=tokenizer.model_max_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
|
|
||||||
text_input_ids = text_inputs.input_ids
|
|
||||||
untruncated_ids = tokenizer(
|
|
||||||
prompt, padding="longest", return_tensors="pt"
|
|
||||||
).input_ids
|
|
||||||
|
|
||||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
|
||||||
-1
|
|
||||||
] and not torch.equal(text_input_ids, untruncated_ids):
|
|
||||||
removed_text = tokenizer.batch_decode(
|
|
||||||
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
|
||||||
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
text_encoder_output = text_encoder("forward", (text_input_ids,))
|
|
||||||
prompt_embeds = torch.from_numpy(text_encoder_output[0])
|
|
||||||
pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1])
|
|
||||||
|
|
||||||
prompt_embeds_list.append(prompt_embeds)
|
|
||||||
|
|
||||||
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
|
||||||
|
|
||||||
# get unconditional embeddings for classifier free guidance
|
|
||||||
zero_out_negative_prompt = (
|
|
||||||
negative_prompt is None
|
|
||||||
and self.config.force_zeros_for_empty_prompt
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
do_classifier_free_guidance
|
|
||||||
and negative_prompt_embeds is None
|
|
||||||
and zero_out_negative_prompt
|
|
||||||
):
|
|
||||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
|
||||||
negative_pooled_prompt_embeds = torch.zeros_like(
|
|
||||||
pooled_prompt_embeds
|
|
||||||
)
|
|
||||||
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
|
||||||
negative_prompt = negative_prompt or ""
|
|
||||||
negative_prompt_2 = negative_prompt
|
|
||||||
|
|
||||||
uncond_tokens: List[str]
|
|
||||||
if prompt is not None and type(prompt) is not type(
|
|
||||||
negative_prompt
|
|
||||||
):
|
|
||||||
raise TypeError(
|
|
||||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
|
||||||
f" {type(prompt)}."
|
|
||||||
)
|
|
||||||
elif isinstance(negative_prompt, str):
|
|
||||||
uncond_tokens = [negative_prompt, negative_prompt_2]
|
|
||||||
elif batch_size != len(negative_prompt):
|
|
||||||
raise ValueError(
|
|
||||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
|
||||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
|
||||||
" the batch size of `prompt`."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
uncond_tokens = [negative_prompt, negative_prompt_2]
|
|
||||||
|
|
||||||
negative_prompt_embeds_list = []
|
|
||||||
for negative_prompt, tokenizer, text_encoder in zip(
|
|
||||||
uncond_tokens, tokenizers, text_encoders
|
|
||||||
):
|
|
||||||
max_length = prompt_embeds.shape[1]
|
|
||||||
uncond_input = tokenizer(
|
|
||||||
negative_prompt,
|
|
||||||
padding="max_length",
|
|
||||||
max_length=max_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
text_encoder_output = text_encoder(
|
|
||||||
"forward", (uncond_input.input_ids,)
|
|
||||||
)
|
|
||||||
negative_prompt_embeds = torch.from_numpy(
|
|
||||||
text_encoder_output[0]
|
|
||||||
)
|
|
||||||
negative_pooled_prompt_embeds = torch.from_numpy(
|
|
||||||
text_encoder_output[1]
|
|
||||||
)
|
|
||||||
|
|
||||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
|
||||||
|
|
||||||
negative_prompt_embeds = torch.concat(
|
|
||||||
negative_prompt_embeds_list, dim=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.ondemand:
|
|
||||||
self.unload_clip_sdxl()
|
|
||||||
|
|
||||||
# TODO: Look into dtype for text_encoder_2!
|
|
||||||
prompt_embeds = prompt_embeds.to(dtype=torch.float32)
|
|
||||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
|
||||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
|
||||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
|
||||||
prompt_embeds = prompt_embeds.view(
|
|
||||||
bs_embed * num_images_per_prompt, seq_len, -1
|
|
||||||
)
|
|
||||||
|
|
||||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
|
||||||
seq_len = negative_prompt_embeds.shape[1]
|
|
||||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32)
|
|
||||||
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
|
||||||
1, num_images_per_prompt, 1
|
|
||||||
)
|
|
||||||
negative_prompt_embeds = negative_prompt_embeds.view(
|
|
||||||
batch_size * num_images_per_prompt, seq_len, -1
|
|
||||||
)
|
|
||||||
|
|
||||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(
|
|
||||||
1, num_images_per_prompt
|
|
||||||
).view(bs_embed * num_images_per_prompt, -1)
|
|
||||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
|
|
||||||
1, num_images_per_prompt
|
|
||||||
).view(bs_embed * num_images_per_prompt, -1)
|
|
||||||
|
|
||||||
return (
|
|
||||||
prompt_embeds,
|
|
||||||
negative_prompt_embeds,
|
|
||||||
pooled_prompt_embeds,
|
|
||||||
negative_pooled_prompt_embeds,
|
|
||||||
)
|
|
||||||
|
|
||||||
def encode_prompts(self, prompts, neg_prompts, max_length):
|
def encode_prompts(self, prompts, neg_prompts, max_length):
|
||||||
# Tokenize text and get embeddings
|
# Tokenize text and get embeddings
|
||||||
text_input = self.tokenizer(
|
text_input = self.tokenizer(
|
||||||
@@ -507,69 +306,6 @@ class StableDiffusionPipeline:
|
|||||||
all_latents = torch.cat(latent_history, dim=0)
|
all_latents = torch.cat(latent_history, dim=0)
|
||||||
return all_latents
|
return all_latents
|
||||||
|
|
||||||
def produce_img_latents_sdxl(
|
|
||||||
self,
|
|
||||||
latents,
|
|
||||||
total_timesteps,
|
|
||||||
add_text_embeds,
|
|
||||||
add_time_ids,
|
|
||||||
prompt_embeds,
|
|
||||||
cpu_scheduling,
|
|
||||||
guidance_scale,
|
|
||||||
dtype,
|
|
||||||
):
|
|
||||||
self.status = SD_STATE_IDLE
|
|
||||||
step_time_sum = 0
|
|
||||||
extra_step_kwargs = {"generator": None}
|
|
||||||
self.load_unet()
|
|
||||||
for i, t in tqdm(enumerate(total_timesteps)):
|
|
||||||
step_start_time = time.time()
|
|
||||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
|
||||||
# expand the latents if we are doing classifier free guidance
|
|
||||||
latent_model_input = torch.cat([latents] * 2)
|
|
||||||
|
|
||||||
latent_model_input = self.scheduler.scale_model_input(
|
|
||||||
latent_model_input, t
|
|
||||||
).to(dtype)
|
|
||||||
|
|
||||||
noise_pred = self.unet(
|
|
||||||
"forward",
|
|
||||||
(
|
|
||||||
latent_model_input,
|
|
||||||
timestep,
|
|
||||||
prompt_embeds,
|
|
||||||
add_text_embeds,
|
|
||||||
add_time_ids,
|
|
||||||
guidance_scale,
|
|
||||||
),
|
|
||||||
send_to_host=False,
|
|
||||||
)
|
|
||||||
latents = self.scheduler.step(
|
|
||||||
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
step_time = (time.time() - step_start_time) * 1000
|
|
||||||
step_time_sum += step_time
|
|
||||||
|
|
||||||
if self.status == SD_STATE_CANCEL:
|
|
||||||
break
|
|
||||||
if self.ondemand:
|
|
||||||
self.unload_unet()
|
|
||||||
avg_step_time = step_time_sum / len(total_timesteps)
|
|
||||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
|
||||||
|
|
||||||
return latents
|
|
||||||
|
|
||||||
def decode_latents_sdxl(self, latents):
|
|
||||||
latents = latents.to(torch.float32)
|
|
||||||
images = self.vae("forward", (latents,))
|
|
||||||
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
|
|
||||||
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
|
||||||
images = (images * 255).round().astype("uint8")
|
|
||||||
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
|
|
||||||
|
|
||||||
return pil_images
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
@@ -619,7 +355,6 @@ class StableDiffusionPipeline:
|
|||||||
"OutpaintPipeline",
|
"OutpaintPipeline",
|
||||||
]
|
]
|
||||||
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
|
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
|
||||||
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]
|
|
||||||
|
|
||||||
sd_model = SharkifyStableDiffusionModel(
|
sd_model = SharkifyStableDiffusionModel(
|
||||||
model_id,
|
model_id,
|
||||||
@@ -636,7 +371,6 @@ class StableDiffusionPipeline:
|
|||||||
debug=debug,
|
debug=debug,
|
||||||
is_inpaint=is_inpaint,
|
is_inpaint=is_inpaint,
|
||||||
is_upscaler=is_upscaler,
|
is_upscaler=is_upscaler,
|
||||||
is_sdxl=is_sdxl,
|
|
||||||
use_stencil=use_stencil,
|
use_stencil=use_stencil,
|
||||||
use_lora=use_lora,
|
use_lora=use_lora,
|
||||||
use_quantize=use_quantize,
|
use_quantize=use_quantize,
|
||||||
|
|||||||
@@ -8,15 +8,6 @@
|
|||||||
"dtype":"i64"
|
"dtype":"i64"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"sdxl_clip": {
|
|
||||||
"token" : {
|
|
||||||
"shape" : [
|
|
||||||
"1*batch_size",
|
|
||||||
"max_len"
|
|
||||||
],
|
|
||||||
"dtype":"i64"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"vae_encode": {
|
"vae_encode": {
|
||||||
"image" : {
|
"image" : {
|
||||||
"shape" : [
|
"shape" : [
|
||||||
@@ -188,49 +179,6 @@
|
|||||||
"shape": [2],
|
"shape": [2],
|
||||||
"dtype": "i64"
|
"dtype": "i64"
|
||||||
}
|
}
|
||||||
},
|
|
||||||
"stabilityai/stable-diffusion-xl-base-1.0": {
|
|
||||||
"latents": {
|
|
||||||
"shape": [
|
|
||||||
"2*batch_size",
|
|
||||||
4,
|
|
||||||
"height",
|
|
||||||
"width"
|
|
||||||
],
|
|
||||||
"dtype": "f32"
|
|
||||||
},
|
|
||||||
"timesteps": {
|
|
||||||
"shape": [
|
|
||||||
1
|
|
||||||
],
|
|
||||||
"dtype": "f32"
|
|
||||||
},
|
|
||||||
"prompt_embeds": {
|
|
||||||
"shape": [
|
|
||||||
"2*batch_size",
|
|
||||||
"max_len",
|
|
||||||
2048
|
|
||||||
],
|
|
||||||
"dtype": "f32"
|
|
||||||
},
|
|
||||||
"text_embeds": {
|
|
||||||
"shape": [
|
|
||||||
"2*batch_size",
|
|
||||||
1280
|
|
||||||
],
|
|
||||||
"dtype": "f32"
|
|
||||||
},
|
|
||||||
"time_ids": {
|
|
||||||
"shape": [
|
|
||||||
"2*batch_size",
|
|
||||||
6
|
|
||||||
],
|
|
||||||
"dtype": "f32"
|
|
||||||
},
|
|
||||||
"guidance_scale": {
|
|
||||||
"shape": 1,
|
|
||||||
"dtype": "f32"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"stencil_adaptor": {
|
"stencil_adaptor": {
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ p.add_argument(
|
|||||||
"--height",
|
"--height",
|
||||||
type=int,
|
type=int,
|
||||||
default=512,
|
default=512,
|
||||||
choices=range(128, 1025, 8),
|
choices=range(128, 769, 8),
|
||||||
help="The height of the output image.",
|
help="The height of the output image.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ p.add_argument(
|
|||||||
"--width",
|
"--width",
|
||||||
type=int,
|
type=int,
|
||||||
default=512,
|
default=512,
|
||||||
choices=range(128, 1025, 8),
|
choices=range(128, 769, 8),
|
||||||
help="The width of the output image.",
|
help="The width of the output image.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
|||||||
from apps.stable_diffusion.src import (
|
from apps.stable_diffusion.src import (
|
||||||
args,
|
args,
|
||||||
Text2ImagePipeline,
|
Text2ImagePipeline,
|
||||||
Text2ImageSDXLPipeline,
|
|
||||||
get_schedulers,
|
get_schedulers,
|
||||||
set_init_device_flags,
|
set_init_device_flags,
|
||||||
utils,
|
utils,
|
||||||
@@ -160,37 +159,8 @@ def txt2img_inf(
|
|||||||
)
|
)
|
||||||
global_obj.set_schedulers(get_schedulers(model_id))
|
global_obj.set_schedulers(get_schedulers(model_id))
|
||||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||||
if height == 1024:
|
global_obj.set_sd_obj(
|
||||||
assert (
|
Text2ImagePipeline.from_pretrained(
|
||||||
width == 1024
|
|
||||||
), "currently we support only 1024x1024 image size via SDXL"
|
|
||||||
assert precision == "fp16", "currently we support fp16 for SDXL"
|
|
||||||
# For SDXL we set max_length as 77.
|
|
||||||
max_length = 77
|
|
||||||
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
|
|
||||||
scheduler=scheduler_obj,
|
|
||||||
import_mlir=args.import_mlir,
|
|
||||||
model_id=args.hf_model_id,
|
|
||||||
ckpt_loc=args.ckpt_loc,
|
|
||||||
precision=precision,
|
|
||||||
max_length=max_length,
|
|
||||||
batch_size=batch_size,
|
|
||||||
height=height,
|
|
||||||
width=width,
|
|
||||||
use_base_vae=args.use_base_vae,
|
|
||||||
use_tuned=args.use_tuned,
|
|
||||||
custom_vae=args.custom_vae,
|
|
||||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
|
||||||
debug=args.import_debug if args.import_mlir else False,
|
|
||||||
use_lora=args.use_lora,
|
|
||||||
use_quantize=args.use_quantize,
|
|
||||||
ondemand=args.ondemand,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
height <= 768 and width <= 768
|
|
||||||
), "height/width not in supported range"
|
|
||||||
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
|
||||||
scheduler=scheduler_obj,
|
scheduler=scheduler_obj,
|
||||||
import_mlir=args.import_mlir,
|
import_mlir=args.import_mlir,
|
||||||
model_id=args.hf_model_id,
|
model_id=args.hf_model_id,
|
||||||
@@ -198,18 +168,17 @@ def txt2img_inf(
|
|||||||
precision=args.precision,
|
precision=args.precision,
|
||||||
max_length=args.max_length,
|
max_length=args.max_length,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
height=height,
|
height=args.height,
|
||||||
width=width,
|
width=args.width,
|
||||||
use_base_vae=args.use_base_vae,
|
use_base_vae=args.use_base_vae,
|
||||||
use_tuned=args.use_tuned,
|
use_tuned=args.use_tuned,
|
||||||
custom_vae=args.custom_vae,
|
custom_vae=args.custom_vae,
|
||||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||||
debug=args.import_debug if args.import_mlir else False,
|
debug=args.import_debug if args.import_mlir else False,
|
||||||
use_lora=args.use_lora,
|
use_lora=args.use_lora,
|
||||||
use_quantize=args.use_quantize,
|
|
||||||
ondemand=args.ondemand,
|
ondemand=args.ondemand,
|
||||||
)
|
)
|
||||||
global_obj.set_sd_obj(txt2img_obj)
|
)
|
||||||
|
|
||||||
global_obj.set_sd_scheduler(scheduler)
|
global_obj.set_sd_scheduler(scheduler)
|
||||||
|
|
||||||
@@ -533,15 +502,15 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
|||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
height = gr.Slider(
|
height = gr.Slider(
|
||||||
128,
|
384,
|
||||||
1024,
|
768,
|
||||||
value=args.height,
|
value=args.height,
|
||||||
step=8,
|
step=8,
|
||||||
label="Height",
|
label="Height",
|
||||||
)
|
)
|
||||||
width = gr.Slider(
|
width = gr.Slider(
|
||||||
128,
|
384,
|
||||||
1024,
|
768,
|
||||||
value=args.width,
|
value=args.width,
|
||||||
step=8,
|
step=8,
|
||||||
label="Width",
|
label="Width",
|
||||||
|
|||||||
Reference in New Issue
Block a user