Merge branch 'main' into ean-turbine-gen

Ean Garvey
2023-12-06 18:41:11 -06:00
committed by GitHub
44 changed files with 4083 additions and 580 deletions

File diff suppressed because it is too large.


@@ -1,4 +1,5 @@
import torch
import time
class FirstVicunaLayer(torch.nn.Module):
@@ -66,7 +67,6 @@ class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers, lmhead, embedding, norm):
super().__init__()
self.model = model
# assert len(layers) == len(model.model.layers)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers = layers
@@ -110,9 +110,11 @@ class LMHeadCompiled(torch.nn.Module):
self.model = shark_module
def forward(self, hidden_states):
hidden_states = hidden_states.detach()
hidden_states_sample = hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
@@ -136,8 +138,9 @@ class VicunaNormCompiled(torch.nn.Module):
hidden_states.detach()
except:
pass
output = self.model("forward", (hidden_states,))
output = self.model("forward", (hidden_states,), send_to_host=True)
output = torch.tensor(output)
return output
@@ -158,15 +161,18 @@ class VicunaEmbeddingCompiled(torch.nn.Module):
def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,))
output = self.model("forward", (input_ids,), send_to_host=True)
output = torch.tensor(output)
return output
class CompiledVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
def __init__(self, shark_module, idx, breakpoints):
super().__init__()
self.model = shark_module
self.idx = idx
self.breakpoints = breakpoints
def forward(
self,
@@ -177,10 +183,11 @@ class CompiledVicunaLayer(torch.nn.Module):
output_attentions=False,
use_cache=True,
):
if self.breakpoints is None:
is_breakpoint = False
else:
is_breakpoint = self.idx + 1 in self.breakpoints
if past_key_value is None:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"first_vicuna_forward",
(
@@ -188,11 +195,17 @@ class CompiledVicunaLayer(torch.nn.Module):
attention_mask,
position_ids,
),
send_to_host=is_breakpoint,
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
if is_breakpoint:
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
else:
output0 = output[0]
output1 = output[1]
output2 = output[2]
return (
output0,
@@ -202,11 +215,8 @@ class CompiledVicunaLayer(torch.nn.Module):
),
)
else:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
pkv0 = past_key_value[0]
pkv1 = past_key_value[1]
output = self.model(
"second_vicuna_forward",
(
@@ -216,11 +226,17 @@ class CompiledVicunaLayer(torch.nn.Module):
pkv0,
pkv1,
),
send_to_host=is_breakpoint,
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
if is_breakpoint:
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
else:
output0 = output[0]
output1 = output[1]
output2 = output[2]
return (
output0,
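The `send_to_host` gating above is the core of this change: activations stay resident on the device between shards, and a `torch.Tensor` is only materialized at user-requested breakpoints. A minimal sketch of the same pattern, assuming `shark_module` is any compiled SHARK module exposing the `(func_name, inputs, send_to_host=...)` calling convention used throughout this file:

import torch

class GatedLayer(torch.nn.Module):
    # Copies outputs back to the host only at breakpoint layers.
    def __init__(self, shark_module, idx, breakpoints=None):
        super().__init__()
        self.model = shark_module
        self.idx = idx
        self.breakpoints = breakpoints or set()

    def forward(self, hidden_states):
        to_host = (self.idx + 1) in self.breakpoints
        out = self.model("forward", (hidden_states,), send_to_host=to_host)
        # Host-side results become torch.Tensors; device buffers pass through.
        return torch.tensor(out) if to_host else out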


@@ -105,6 +105,7 @@ def main():
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
control_mode=args.control_mode,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"


@@ -0,0 +1,96 @@
import torch
import time
from apps.stable_diffusion.src import (
args,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
def main():
if args.clear_all:
clear_all()
# TODO: prompt_embeds and text_embeds from base_model.json require fixing
args.precision = "fp16"
args.height = 1024
args.width = 1024
args.max_length = 77
args.scheduler = "DDIM"
print(
"Using default supported configuration for SDXL :-\nprecision=fp16, width*height= 1024*1024, max_length=77 and scheduler=DDIM"
)
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += (
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
)
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: with --batch_count=x, txt2img_obj.log re-displays info from every iteration since the start on each display
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seeds[current_batch])
print(text_output)
if __name__ == "__main__":
main()


@@ -19,6 +19,9 @@ a = Analysis(
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
module_collection_mode={
'gradio': 'py', # Collect gradio package as source .py files
},
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)


@@ -31,6 +31,7 @@ datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += copy_metadata("gradio")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
@@ -75,6 +76,7 @@ datas += [
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("diffusers") if "tests" not in x
]
@@ -85,4 +87,4 @@ hiddenimports += [
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]
hiddenimports += ["iree._runtime"]


@@ -9,6 +9,7 @@ from apps.stable_diffusion.src.utils import (
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
Text2ImageSDXLPipeline,
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,


@@ -1,5 +1,5 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection
from collections import defaultdict
from pathlib import Path
import torch
@@ -24,6 +24,8 @@ from apps.stable_diffusion.src.utils import (
get_stencil_model_id,
update_lora_weight,
)
from shark.shark_downloader import download_public_file
from shark.shark_inference import SharkInference
# These shapes are parameter dependent.
@@ -55,6 +57,10 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape.append(math.ceil(height / div_val))
elif "width" in shape[i]:
new_shape.append(math.ceil(width / div_val))
elif "+" in shape[i]:
# Currently this case is only hit for SDXL. If any other case
# requires this operator, change this accordingly.
new_shape.append(height + width)
else:
new_shape.append(shape[i])
return new_shape
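# Worked example (hypothetical spec; the real specs live in base_model.json):
# ["max_len", "height/8", "width/8", "height+width"] with max_len=77 and
# height=width=1024 resolves to [77, 128, 128, 2048]; the "+" branch is the
# SDXL-only case handled above.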
@@ -67,6 +73,70 @@ def check_compilation(model, model_name):
)
def shark_compile_after_ir(
module_name,
device,
vmfb_path,
precision,
ir_path=None,
):
if ir_path:
print(f"[DEBUG] mlir found at {ir_path.absolute()}")
module = SharkInference(
mlir_module=ir_path,
device=device,
mlir_dialect="tm_tensor",
)
print(f"Will get extra flag for {module_name} and precision = {precision}")
path = module.save_module(
vmfb_path.parent.absolute(),
vmfb_path.stem,
extra_args=get_opt_flags(module_name, precision=precision),
)
print(f"Saved {module_name} vmfb at {path}")
module.load_module(path)
return module
def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision):
name_split = extended_model_name.split("_")
if "vae" in model_name:
name_split[5] = "fp32"
extended_model_name_for_vmfb = "_".join(name_split)
extended_model_name_for_mlir = "_".join(name_split[:-1])
vmfb_path = Path(extended_model_name_for_vmfb + ".vmfb")
if "vulkan" in device:
_device = args.iree_vulkan_target_triple
_device = _device.replace("-", "_")
vmfb_path = Path(extended_model_name_for_vmfb + "_vulkan.vmfb")
if vmfb_path.exists():
shark_module = SharkInference(
None,
device=device,
mlir_dialect="tm_tensor",
)
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=[])
return shark_module, None
mlir_path = Path(extended_model_name_for_mlir + ".mlir")
if not mlir_path.exists():
print(f"Looking into gs://shark_tank/SDXL/mlir/{mlir_path.name}")
download_public_file(
f"gs://shark_tank/SDXL/mlir/{mlir_path.name}",
mlir_path.absolute(),
single_file=True,
)
if mlir_path.exists():
return (
shark_compile_after_ir(
model_name, device, vmfb_path, precision, mlir_path
),
None,
)
return None, None
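# Callers (get_vae_sdxl, get_unet_sdxl, get_clip_sdxl below) treat a
# (None, None) return as a cache/download miss and fall back to tracing
# the model locally through compile_through_fx.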
class SharkifyStableDiffusionModel:
def __init__(
self,
@@ -86,13 +156,15 @@ class SharkifyStableDiffusionModel:
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
is_sdxl: bool = False,
stencils: list[str] = [],
use_lora: str = "",
use_quantize: str = None,
return_mlir: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.is_sdxl = is_sdxl
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
@@ -144,7 +216,7 @@ class SharkifyStableDiffusionModel:
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.is_upscaler = is_upscaler
self.use_stencil = get_stencil_model_id(use_stencil)
self.stencils = [get_stencil_model_id(x) for x in stencils]
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora
@@ -175,14 +247,15 @@ class SharkifyStableDiffusionModel:
model_name = {}
sub_model_list = [
"clip",
"clip2",
"unet",
"unet512",
"stencil_unet",
"stencil_unet_512",
"vae",
"vae_encode",
"stencil_adaptor",
"stencil_adaptor_512",
"stencil_adapter",
"stencil_adapter_512",
]
index = 0
for model in sub_model_list:
@@ -195,10 +268,19 @@ class SharkifyStableDiffusionModel:
)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
model_config = model_config + get_path_stem(self.use_stencil)
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
if "stencil_adapter" in model:
stencil_names = []
for i, stencil in enumerate(self.stencils):
if stencil is not None:
cnet_config = model_config + stencil.split("_")[-1]
stencil_names.append(
get_extended_name(sub_model + cnet_config)
)
model_name[model] = stencil_names
else:
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
def check_params(self, max_len, width, height):
@@ -342,6 +424,105 @@ class SharkifyStableDiffusionModel:
)
return shark_vae, vae_mlir
def get_vae_sdxl(self):
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
self.model_name["vae"], "vae", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class VaeModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
base_vae=self.base_vae,
custom_vae=self.custom_vae,
low_cpu_mem_usage=False,
):
super().__init__()
self.vae = None
if custom_vae == "":
print(f"Loading default vae, with target {model_id}")
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
precision = "fp16" if "fp16" in custom_vae else None
print(f"Loading custom vae, with target {custom_vae}")
if os.path.exists(custom_vae):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
custom_vae = "/".join(
[
custom_vae.split("/")[-2].split("\\")[-1],
custom_vae.split("/")[-1],
]
)
print("Using hub to get custom vae")
try:
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
variant=precision,
)
except:
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
print(f"Loading custom vae, with state {custom_vae}")
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae
def forward(self, latents):
image = self.vae.decode(latents / 0.13025, return_dict=False)[
0
]
return image
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
# Make sure the VAE runs in float32, since it overflows in float16,
# per the SDXL pipeline.
if not self.custom_vae:
is_f16 = False
elif "16" in self.custom_vae:
is_f16 = True
else:
is_f16 = False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae, vae_mlir = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae, vae_mlir
def get_controlled_unet(self, use_large=False):
class ControlledUnetModel(torch.nn.Module):
def __init__(
@@ -380,25 +561,54 @@ class SharkifyStableDiffusionModel:
control11,
control12,
control13,
scale1,
scale2,
scale3,
scale4,
scale5,
scale6,
scale7,
scale8,
scale9,
scale10,
scale11,
scale12,
scale13,
):
# TODO: Average pooling
db_res_samples = [
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
]
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple(
[
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
control1 * scale1,
control2 * scale2,
control3 * scale3,
control4 * scale4,
control5 * scale5,
control6 * scale6,
control7 * scale7,
control8 * scale8,
control9 * scale9,
control10 * scale10,
control11 * scale11,
control12 * scale12,
]
)
mb_res_samples = control13
mb_res_samples = control13 * scale13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents,
@@ -446,6 +656,19 @@ class SharkifyStableDiffusionModel:
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
]
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
@@ -462,17 +685,19 @@ class SharkifyStableDiffusionModel:
)
return shark_controlled_unet, controlled_unet_mlir
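# Scaling each residual by its own scalar is just an element-wise zip,
# i.e. tuple(c * s for c, s in zip(controls, scales)); the 13 flat
# tensor arguments exist because the compiled module's ABI takes plain
# tensors rather than Python sequences.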
def get_control_net(self, use_large=False):
def get_control_net(self, stencil_id, use_large=False):
stencil_id = get_stencil_model_id(stencil_id)
adapter_id, base_model_safe_id, ext_model_name = (None, None, None)
print(f"Importing ControlNet adapter from {stencil_id}")
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
):
def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False):
super().__init__()
self.cnet = ControlNetModel.from_pretrained(
model_id,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.cnet.in_channels
self.in_channels = self.cnet.config.in_channels
self.train(False)
def forward(
@@ -481,6 +706,19 @@ class SharkifyStableDiffusionModel:
timestep,
text_embedding,
stencil_image_input,
acc1,
acc2,
acc3,
acc4,
acc5,
acc6,
acc7,
acc8,
acc9,
acc10,
acc11,
acc12,
acc13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
# TODO: guidance is NOT NEEDED here; change in `get_input_info` later
@@ -502,6 +740,20 @@ class SharkifyStableDiffusionModel:
)
return tuple(
list(down_block_res_samples) + [mid_block_res_sample]
) + (
acc1 + down_block_res_samples[0],
acc2 + down_block_res_samples[1],
acc3 + down_block_res_samples[2],
acc4 + down_block_res_samples[3],
acc5 + down_block_res_samples[4],
acc6 + down_block_res_samples[5],
acc7 + down_block_res_samples[6],
acc8 + down_block_res_samples[7],
acc9 + down_block_res_samples[8],
acc10 + down_block_res_samples[9],
acc11 + down_block_res_samples[10],
acc12 + down_block_res_samples[11],
acc13 + mid_block_res_sample,
)
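# Note: the trailing tuple returns running sums (acc_i + residual_i),
# so several compiled ControlNets can be chained by feeding one
# module's accumulator outputs into the next module's acc1..acc13.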
scnet = StencilControlNetModel(
@@ -509,7 +761,23 @@ class SharkifyStableDiffusionModel:
)
is_f16 = self.precision == "fp16"
inputs = tuple(self.inputs["stencil_adaptor"])
inputs = tuple(self.inputs["stencil_adapter"])
model_name = "stencil_adapter_512" if use_large else "stencil_adapter"
ext_model_name = self.model_name[model_name]
if isinstance(ext_model_name, list):
desired_name = None
for i in ext_model_name:
if stencil_id.split("_")[-1] in i:
desired_name = i
print(f"Multi-CN: compiling model {i}")
if desired_name:
ext_model_name = desired_name
else:
raise Exception(
f"Could not find extended configuration for {stencil_id}"
)
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
@@ -517,21 +785,15 @@ class SharkifyStableDiffusionModel:
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
*inputs[3:],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor"]
)
input_mask = [True, True, True, True]
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
save_dir = os.path.join(self.sharktank_dir, ext_model_name)
input_mask = [True, True, True, True] + ([True] * 13)
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name[model_name],
extended_model_name=ext_model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
@@ -688,6 +950,101 @@ class SharkifyStableDiffusionModel:
)
return shark_unet, unet_mlir
def get_unet_sdxl(self):
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
self.model_name["unet"], "unet", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class UnetModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
):
super().__init__()
try:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
variant="fp16",
)
except:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
):
if args.attention_slicing.isdigit():
self.unet.set_attention_slice(
int(args.attention_slicing)
)
else:
self.unet.set_attention_slice(args.attention_slicing)
def forward(
self,
latent,
timestep,
prompt_embeds,
text_embeds,
time_ids,
guidance_scale,
):
added_cond_kwargs = {
"text_embeds": text_embeds,
"time_ids": time_ids,
}
noise_pred = self.unet.forward(
latent,
timestep,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=None,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = self.precision == "fp16"
inputs = tuple(self.inputs["unet"])
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
input_mask = [True, True, True, True, True, True]
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
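# The wrapper folds classifier-free guidance into the compiled graph:
# the latent batch carries [uncond, text] halves, so a (2, 4, 128, 128)
# noise_pred chunks into two (1, 4, 128, 128) halves combined as
# uncond + guidance_scale * (text - uncond).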
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(
@@ -735,6 +1092,78 @@ class SharkifyStableDiffusionModel:
)
return shark_clip, clip_mlir
def get_clip_sdxl(self, clip_index=1):
if clip_index == 1:
extended_model_name = self.model_name["clip"]
model_name = "clip"
else:
extended_model_name = self.model_name["clip2"]
model_name = "clip2"
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
extended_model_name, f"clip", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class CLIPText(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
clip_index=1,
):
super().__init__()
if clip_index == 1:
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.text_encoder = (
CLIPTextModelWithProjection.from_pretrained(
model_id,
subfolder="text_encoder_2",
low_cpu_mem_usage=low_cpu_mem_usage,
)
)
def forward(self, input):
prompt_embeds = self.text_encoder(
input,
output_hidden_states=True,
)
# We are always only interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
return prompt_embeds, pooled_prompt_embeds
clip_model = CLIPText(
low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index
)
save_dir = os.path.join(self.sharktank_dir, extended_model_name)
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip, clip_mlir = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
extended_model_name=extended_model_name,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
model_name="clip",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_clip, clip_mlir
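# Assuming standard SDXL dims (768 for CLIPTextModel, 1280 for the
# projection model), each encoder's hidden_states[-2] is (1, 77, 768)
# and (1, 77, 1280) respectively; the pipeline later concatenates them
# along dim=-1 into (1, 77, 2048) prompt embeddings.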
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
if not custom_vae.endswith((".ckpt", ".safetensors")):
@@ -767,7 +1196,9 @@ class SharkifyStableDiffusionModel:
}
return vae_dict
def compile_unet_variants(self, model, use_large=False):
def compile_unet_variants(self, model, use_large=False, base_model=""):
if self.is_sdxl:
return self.get_unet_sdxl()
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler(use_large=use_large)
@@ -809,9 +1240,28 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def sdxl_clip(self):
try:
self.inputs["clip"] = self.get_input_info_for(
base_models["sdxl_clip"]
)
compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1)
compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2)
check_compilation(compiled_clip, "Clip")
check_compilation(compiled_clip2, "Clip2")
if self.return_mlir:
return clip_mlir, clip_mlir2
return compiled_clip, compiled_clip2
except Exception as e:
sys.exit(e)
def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
stencil_count = 0
for stencil in self.stencils:
stencil_count += 1
model = "stencil_unet" if stencil_count > 0 else "unet"
compiled_unet = None
unet_inputs = base_models[model]
@@ -820,7 +1270,7 @@ class SharkifyStableDiffusionModel:
unet_inputs[self.base_model_id]
)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
model, use_large=use_large, base_model=self.base_model_id
)
else:
for model_id in unet_inputs:
@@ -831,7 +1281,7 @@ class SharkifyStableDiffusionModel:
try:
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
model, use_large=use_large, base_model=model_id
)
except Exception as e:
print(e)
@@ -870,7 +1320,10 @@ class SharkifyStableDiffusionModel:
is_base_vae = self.base_vae
if self.is_upscaler:
self.base_vae = True
compiled_vae, vae_mlir = self.get_vae()
if self.is_sdxl:
compiled_vae, vae_mlir = self.get_vae_sdxl()
else:
compiled_vae, vae_mlir = self.get_vae()
self.base_vae = is_base_vae
check_compilation(compiled_vae, "Vae")
@@ -880,18 +1333,18 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def controlnet(self, use_large=False):
def controlnet(self, stencil_id, use_large=False):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
self.inputs["stencil_adapter"] = self.get_input_info_for(
base_models["stencil_adapter"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
use_large=use_large
compiled_stencil_adapter, controlnet_mlir = self.get_control_net(
stencil_id, use_large=use_large
)
check_compilation(compiled_stencil_adaptor, "Stencil")
check_compilation(compiled_stencil_adapter, "Stencil")
if self.return_mlir:
return controlnet_mlir
return compiled_stencil_adaptor
return compiled_stencil_adapter
except Exception as e:
sys.exit(e)


@@ -123,8 +123,11 @@ def get_clip():
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer():
def get_tokenizer(subfolder="tokenizer", hf_model_id=None):
if hf_model_id is not None:
args.hf_model_id = hf_model_id
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder="tokenizer"
args.hf_model_id, subfolder=subfolder
)
return tokenizer
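With the new `subfolder` argument, the second SDXL tokenizer is fetched through the same helper; note that passing `hf_model_id` also overwrites `args.hf_model_id` as a side effect. For example, with the SDXL base checkpoint targeted elsewhere in this change:

tokenizer_2 = get_tokenizer(
    subfolder="tokenizer_2",
    hf_model_id="stabilityai/stable-diffusion-xl-base-1.0",
)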


@@ -1,6 +1,9 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
Text2ImageSDXLPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
Image2ImagePipeline,
)


@@ -158,8 +158,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
stencils,
images,
resample_type,
control_mode,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):


@@ -55,28 +55,47 @@ class StencilPipeline(StableDiffusionPipeline):
import_mlir: bool,
use_lora: str,
ondemand: bool,
controlnet_names: list[str],
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
self.controlnet_512 = None
self.controlnet = [None] * len(controlnet_names)
self.controlnet_512 = [None] * len(controlnet_names)
self.controlnet_id = [None] * len(controlnet_names)
self.controlnet_512_id = [None] * len(controlnet_names)
self.controlnet_names = controlnet_names
def load_controlnet(self):
if self.controlnet is not None:
def load_controlnet(self, index, model_name):
if model_name is None:
return
self.controlnet = self.sd_model.controlnet()
def unload_controlnet(self):
del self.controlnet
self.controlnet = None
def load_controlnet_512(self):
if self.controlnet_512 is not None:
if (
self.controlnet[index] is not None
and self.controlnet_id[index] is not None
and self.controlnet_id[index] == model_name
):
return
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
self.controlnet_id[index] = model_name
self.controlnet[index] = self.sd_model.controlnet(model_name)
def unload_controlnet_512(self):
del self.controlnet_512
self.controlnet_512 = None
def unload_controlnet(self, index):
del self.controlnet[index]
self.controlnet_id[index] = None
self.controlnet[index] = None
def load_controlnet_512(self, index, model_name):
if (
self.controlnet_512[index] is not None
and self.controlnet_512_id[index] == model_name
):
return
self.controlnet_512_id[index] = model_name
self.controlnet_512[index] = self.sd_model.controlnet(
model_name, use_large=True
)
def unload_controlnet_512(self, index):
del self.controlnet_512[index]
self.controlnet_512_id[index] = None
self.controlnet_512[index] = None
def prepare_latents(
self,
@@ -111,8 +130,9 @@ class StencilPipeline(StableDiffusionPipeline):
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
stencil_hints=[None],
controlnet_conditioning_scale: float = 1.0,
control_mode="Balanced", # Prompt, Balanced, or Controlnet
mask=None,
masked_image_latents=None,
return_all_latents=False,
@@ -121,12 +141,18 @@ class StencilPipeline(StableDiffusionPipeline):
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
assert control_mode in ["Prompt", "Balanced", "Controlnet"]
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
self.load_controlnet()
else:
self.load_unet_512()
self.load_controlnet_512()
for i, name in enumerate(self.controlnet_names):
if text_embeddings.shape[1] <= self.model_max_length:
self.load_controlnet(i, name)
else:
self.load_controlnet_512(i, name)
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
@@ -149,33 +175,93 @@ class StencilPipeline(StableDiffusionPipeline):
).to(dtype)
else:
latent_model_input_1 = latent_model_input
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
else:
control = self.controlnet_512(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
# Multicontrolnet
height = latent_model_input_1.shape[2]
width = latent_model_input_1.shape[3]
dtype = latent_model_input_1.dtype
control_acc = (
[torch.zeros((2, 320, height, width), dtype=dtype)] * 3
+ [
torch.zeros(
(2, 320, int(height / 2), int(width / 2)), dtype=dtype
)
]
+ [
torch.zeros(
(2, 640, int(height / 2), int(width / 2)), dtype=dtype
)
]
* 2
+ [
torch.zeros(
(2, 640, int(height / 4), int(width / 4)), dtype=dtype
)
]
+ [
torch.zeros(
(2, 1280, int(height / 4), int(width / 4)), dtype=dtype
)
]
* 2
+ [
torch.zeros(
(2, 1280, int(height / 8), int(width / 8)), dtype=dtype
)
]
* 4
)
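# The 13 zero accumulators mirror the SD UNet residual shapes at this
# latent size: 4 maps at 320 channels, 3 at 640, and 6 at 1280 (12
# down-block residuals plus the mid-block), with batch 2 for CFG.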
for i, controlnet_hint in enumerate(stencil_hints):
if controlnet_hint is None:
continue
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet[i](
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
*control_acc,
),
send_to_host=False,
)
else:
control = self.controlnet_512[i](
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
*control_acc,
),
send_to_host=False,
)
control_acc = control[13:]
control = control[:13]
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
dtype = latents.dtype
if control_mode == "Balanced":
control_scale = [
torch.tensor(1.0, dtype=dtype) for _ in range(len(control))
]
elif control_mode == "Prompt":
control_scale = [
torch.tensor(0.825**x, dtype=dtype)
for x in range(len(control))
]
elif control_mode == "Controlnet":
control_scale = [
torch.tensor(float(guidance_scale), dtype=dtype)
for _ in range(len(control))
]
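# "Balanced" weights every residual at 1.0; "Prompt" decays deeper
# residuals geometrically (0.825**x) so text guidance dominates; and
# "Controlnet" weights every residual by the guidance scale itself.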
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
@@ -197,6 +283,19 @@ class StencilPipeline(StableDiffusionPipeline):
control[10],
control[11],
control[12],
control_scale[0],
control_scale[1],
control_scale[2],
control_scale[3],
control_scale[4],
control_scale[5],
control_scale[6],
control_scale[7],
control_scale[8],
control_scale[9],
control_scale[10],
control_scale[11],
control_scale[12],
),
send_to_host=False,
)
@@ -222,6 +321,19 @@ class StencilPipeline(StableDiffusionPipeline):
control[10],
control[11],
control[12],
control_scale[0],
control_scale[1],
control_scale[2],
control_scale[3],
control_scale[4],
control_scale[5],
control_scale[6],
control_scale[7],
control_scale[8],
control_scale[9],
control_scale[10],
control_scale[11],
control_scale[12],
),
send_to_host=False,
)
@@ -245,8 +357,9 @@ class StencilPipeline(StableDiffusionPipeline):
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
self.unload_controlnet()
self.unload_controlnet_512()
for i in range(len(self.controlnet_names)):
self.unload_controlnet(i)
self.unload_controlnet_512(i)
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -272,14 +385,29 @@ class StencilPipeline(StableDiffusionPipeline):
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
stencils,
stencil_images,
resample_type,
control_mode,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
controlnet_hint = controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
)
# controlnet_hint = controlnet_hint_conversion(
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
# )
stencil_hints = []
for i, stencil in enumerate(stencils):
image = stencil_images[i]
stencil_hints.append(
controlnet_hint_conversion(
image,
stencil,
height,
width,
dtype,
num_images_per_prompt=1,
)
)
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
@@ -327,7 +455,8 @@ class StencilPipeline(StableDiffusionPipeline):
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
control_mode=control_mode,
stencil_hints=stencil_hints,
)
# Img latents -> PIL images


@@ -18,7 +18,10 @@ from diffusers import (
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)


@@ -0,0 +1,220 @@
import torch
import numpy as np
from random import randint
from typing import Union
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Text2ImageSDXLPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
is_fp32_vae: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.is_fp32_vae = is_fp32_vae
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def _get_add_time_ids(
self, original_size, crops_coords_top_left, target_size, dtype
):
add_time_ids = list(
original_size + crops_coords_top_left + target_size
)
# self.unet.config.addition_time_embed_dim IS 256.
# self.text_encoder_2.config.projection_dim IS 1280.
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
expected_add_embed_dim = 2816
# self.unet.add_embedding.linear_1.in_features IS 2816.
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
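# Worked check for SDXL base: 6 time ids (original h/w, crop top/left,
# target h/w) * 256 + 1280 = 1536 + 1280 = 2816, which matches
# unet.add_embedding.linear_1.in_features.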
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the initial latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings.
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt_sdxl(
prompt=prompts,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt=neg_prompts,
)
# Prepare timesteps.
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# Prepare added time ids & embeddings.
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
add_text_embeds = pooled_prompt_embeds
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
)
prompt_embeds = torch.cat(
[negative_prompt_embeds, prompt_embeds], dim=0
)
add_text_embeds = torch.cat(
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
add_text_embeds = add_text_embeds.to(dtype)
add_time_ids = add_time_ids.repeat(batch_size * 1, 1)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(dtype)
prompt_embeds = prompt_embeds.to(dtype)
add_time_ids = add_time_ids.to(dtype)
# Get Image latents.
latents = self.produce_img_latents_sdxl(
init_latents,
timesteps,
add_text_embeds,
add_time_ids,
prompt_embeds,
cpu_scheduling,
guidance_scale,
dtype,
)
# Img latents -> PIL images.
all_imgs = []
self.load_vae()
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents_sdxl(
latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs
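End to end, the new pipeline is driven like the existing text-to-image one. A usage sketch, assuming `txt2img_obj` was built with `Text2ImageSDXLPipeline.from_pretrained` as in the SDXL script above (positional arguments follow `generate_images` as defined here):

import torch

images = txt2img_obj.generate_images(
    ["a photo of an astronaut riding a horse"],  # prompts
    [""],  # neg_prompts
    1,  # batch_size
    1024,  # height
    1024,  # width
    30,  # num_inference_steps
    7.5,  # guidance_scale
    42,  # seed
    77,  # max_length
    torch.half,  # dtype
    False,  # use_base_vae
    True,  # cpu_scheduling
    1,  # max_embeddings_multiples
)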


@@ -20,7 +20,10 @@ from diffusers import (
HeunDiscreteScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae,
@@ -33,6 +36,8 @@ from apps.stable_diffusion.src.utils import (
end_profiling,
)
import sys
import gc
from typing import List, Optional
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
@@ -50,6 +55,7 @@ class StableDiffusionPipeline:
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
@@ -60,20 +66,23 @@ class StableDiffusionPipeline:
import_mlir: bool,
use_lora: str,
ondemand: bool,
is_f32_vae: bool = False,
):
self.vae = None
self.text_encoder = None
self.text_encoder_2 = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using the Python logging utility.
self.log = ""
self.status = SD_STATE_IDLE
self.sd_model = sd_model
self.scheduler = scheduler
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
self.is_f32_vae = is_f32_vae
# TODO: Find a better workaround for fetching base_model_id early
# enough for CLIPTokenizer.
try:
@@ -106,6 +115,34 @@ class StableDiffusionPipeline:
del self.text_encoder
self.text_encoder = None
def load_clip_sdxl(self):
if self.text_encoder and self.text_encoder_2:
return
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. "
"Importing MLIR anyways."
)
self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip()
else:
try:
# TODO: Fix this for SDXL
self.text_encoder = get_clip()
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
(
self.text_encoder,
self.text_encoder_2,
) = self.sd_model.sdxl_clip()
def unload_clip_sdxl(self):
del self.text_encoder, self.text_encoder_2
self.text_encoder = None
self.text_encoder_2 = None
def load_unet(self):
if self.unet is not None:
return
@@ -159,6 +196,182 @@ class StableDiffusionPipeline:
def unload_vae(self):
del self.vae
self.vae = None
gc.collect()
def encode_prompt_sdxl(
self,
prompt: str,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
hf_model_id: Optional[
str
] = "stabilityai/stable-diffusion-xl-base-1.0",
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
self.tokenizer_2 = get_tokenizer("tokenizer_2", hf_model_id)
self.load_clip_sdxl()
tokenizers = (
[self.tokenizer, self.tokenizer_2]
if self.tokenizer is not None
else [self.tokenizer_2]
)
text_encoders = (
[self.text_encoder, self.text_encoder_2]
if self.text_encoder is not None
else [self.text_encoder_2]
)
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt]
for prompt, tokenizer, text_encoder in zip(
prompts, tokenizers, text_encoders
):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[
-1
] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
)
print(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
text_encoder_output = text_encoder("forward", (text_input_ids,))
prompt_embeds = torch.from_numpy(text_encoder_output[0])
pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1])
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = (
negative_prompt is None
and self.config.force_zeros_for_empty_prompt
)
if (
do_classifier_free_guidance
and negative_prompt_embeds is None
and zero_out_negative_prompt
):
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(
pooled_prompt_embeds
)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(
negative_prompt
):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(
uncond_tokens, tokenizers, text_encoders
):
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_encoder_output = text_encoder(
"forward", (uncond_input.input_ids,)
)
negative_prompt_embeds = torch.from_numpy(
text_encoder_output[0]
)
negative_pooled_prompt_embeds = torch.from_numpy(
text_encoder_output[1]
)
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_prompt_embeds = torch.concat(
negative_prompt_embeds_list, dim=-1
)
if self.ondemand:
self.unload_clip_sdxl()
gc.collect()
# TODO: Look into dtype for text_encoder_2!
prompt_embeds = prompt_embeds.to(dtype=torch.float16)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32)
negative_prompt_embeds = negative_prompt_embeds.repeat(
1, num_images_per_prompt, 1
)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(
1, num_images_per_prompt
).view(bs_embed * num_images_per_prompt, -1)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
1, num_images_per_prompt
).view(bs_embed * num_images_per_prompt, -1)
return (
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
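# The four returned tensors line up with the CFG assembly done in
# generate_images: torch.cat([negative_prompt_embeds, prompt_embeds])
# feeds the UNet's encoder_hidden_states, and
# torch.cat([negative_pooled, pooled]) feeds
# added_cond_kwargs["text_embeds"].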
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
@@ -186,6 +399,7 @@ class StableDiffusionPipeline:
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
@@ -298,6 +512,8 @@ class StableDiffusionPipeline:
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
gc.collect()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -306,6 +522,96 @@ class StableDiffusionPipeline:
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def produce_img_latents_sdxl(
self,
latents,
total_timesteps,
add_text_embeds,
add_time_ids,
prompt_embeds,
cpu_scheduling,
guidance_scale,
dtype,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
extra_step_kwargs = {"generator": None}
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
# expand the latents if we are doing classifier free guidance
if isinstance(latents, np.ndarray):
latents = torch.tensor(latents)
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
prompt_embeds,
add_text_embeds,
add_time_ids,
guidance_scale,
),
send_to_host=True,
)
if not isinstance(latents, torch.Tensor):
latents = torch.from_numpy(latents).to("cpu")
noise_pred = torch.from_numpy(noise_pred).to("cpu")
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
latents = latents.detach().numpy()
noise_pred = noise_pred.detach().numpy()
step_time = (time.time() - step_start_time) * 1000
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
gc.collect()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
return latents
def decode_latents_sdxl(self, latents, is_fp32_vae):
# latents are in the UNet dtype here, so cast if we want to use fp32
if is_fp32_vae:
print("Casting latents to float32 for VAE")
latents = latents.to(torch.float32)
images = self.vae("forward", (latents,))
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
return pil_images
@classmethod
def from_pretrained(
cls,
@@ -338,7 +644,8 @@ class StableDiffusionPipeline:
ondemand: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
stencils: list[str] = [],
# stencil_images: list[Image] = []
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
@@ -355,6 +662,7 @@ class StableDiffusionPipeline:
"OutpaintPipeline",
]
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]
sd_model = SharkifyStableDiffusionModel(
model_id,
@@ -371,7 +679,8 @@ class StableDiffusionPipeline:
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
is_sdxl=is_sdxl,
stencils=stencils,
use_lora=use_lora,
use_quantize=use_quantize,
)
@@ -386,6 +695,21 @@ class StableDiffusionPipeline:
ondemand,
)
if cls.__name__ == "StencilPipeline":
return cls(
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
)
if cls.__name__ == "Text2ImageSDXLPipeline":
is_fp32_vae = "16" not in custom_vae
return cls(
scheduler,
sd_model,
import_mlir,
use_lora,
ondemand,
is_fp32_vae,
)
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
# #####################################################
@@ -498,9 +822,10 @@ class StableDiffusionPipeline:
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings.numpy()
return text_embeddings.numpy().astype(np.float16)
from typing import List, Optional, Union


@@ -1,4 +1,7 @@
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers


@@ -1,4 +1,5 @@
from diffusers import (
LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
@@ -15,9 +16,21 @@ from diffusers import (
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
SharkEulerAncestralDiscreteScheduler,
)
def get_schedulers(model_id):
# TODO: Robust scheduler setup on pipeline creation -- if we don't
# set batch_size here, the SHARK schedulers will
# compile with batch size = 1 regardless of whether the model
# outputs latents of a larger batch size, e.g. SDXL.
# However, obviously, searching for whether the base model ID
# contains "xl" is not very robust.
batch_size = 2 if "xl" in model_id.lower() else 1
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
@@ -39,6 +52,10 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
@@ -84,6 +101,12 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers[
"SharkEulerAncestralDiscrete"
] = SharkEulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverSinglestep"
] = DPMSolverSinglestepScheduler.from_pretrained(
@@ -100,5 +123,6 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
schedulers["SharkEulerDiscrete"].compile(batch_size)
schedulers["SharkEulerAncestralDiscrete"].compile(batch_size)
return schedulers
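Since the SHARK schedulers compile against a fixed latent batch, the "xl" heuristic above sizes them for CFG-doubled SDXL latents. Roughly, the observable effect:

schedulers = get_schedulers("stabilityai/stable-diffusion-xl-base-1.0")
# SharkEulerDiscrete and SharkEulerAncestralDiscrete compile with
# batch_size=2 (uncond + text latents); any non-XL model id compiles
# them with batch_size=1.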


@@ -0,0 +1,251 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
EulerAncestralDiscreteScheduler,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_shark_model,
args,
)
import torch
class SharkEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
super().__init__(
num_train_timesteps,
beta_start,
beta_end,
beta_schedule,
trained_betas,
prediction_type,
timestep_spacing,
steps_offset,
)
# TODO: make it dynamic so we don't have to worry about batch size
self.batch_size = None
self.init_input_shape = None
def compile(self, batch_size=1):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
device = args.device.split(":", 1)[0].strip()
self.batch_size = batch_size
model_input = {
"eulera": {
"output": torch.randn(
batch_size, 4, args.height // 8, args.width // 8
),
"latent": torch.randn(
batch_size, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"sigma_from": torch.tensor(1).to(torch.float32),
"sigma_to": torch.tensor(1).to(torch.float32),
"noise": torch.randn(
batch_size, 4, args.height // 8, args.width // 8
),
},
}
example_latent = model_input["eulera"]["latent"]
example_output = model_input["eulera"]["output"]
example_noise = model_input["eulera"]["noise"]
if args.precision == "fp16":
example_latent = example_latent.half()
example_output = example_output.half()
example_noise = example_noise.half()
example_sigma = model_input["eulera"]["sigma"]
example_sigma_from = model_input["eulera"]["sigma_from"]
example_sigma_to = model_input["eulera"]["sigma_to"]
class ScalingModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepEpsilonModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self, noise_pred, latent, sigma, sigma_from, sigma_to, noise
):
sigma_up = (
sigma_to**2
* (sigma_from**2 - sigma_to**2)
/ sigma_from**2
) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
dt = sigma_down - sigma
pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma
prev_sample = latent + derivative * dt
return prev_sample + noise * sigma_up
class SchedulerStepVPredictionModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self, noise_pred, sigma, sigma_from, sigma_to, latent, noise
):
sigma_up = (
sigma_to**2
* (sigma_from**2 - sigma_to**2)
/ sigma_from**2
) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
dt = sigma_down - sigma
pred_original_sample = noise_pred * (
-sigma / (sigma**2 + 1) ** 0.5
) + (latent / (sigma**2 + 1))
derivative = (latent - pred_original_sample) / sigma
prev_sample = latent + derivative * dt
return prev_sample + noise * sigma_up
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
def _import(self):
scaling_model = ScalingModel()
self.scaling_model, _ = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
extended_model_name=f"euler_a_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
pred_type_model_dict = {
"epsilon": SchedulerStepEpsilonModel(),
"v_prediction": SchedulerStepVPredictionModel(),
}
step_model = pred_type_model_dict[self.config.prediction_type]
self.step_model, _ = compile_through_fx(
step_model,
(
example_output,
example_latent,
example_sigma,
example_sigma_from,
example_sigma_to,
example_noise,
),
extended_model_name=f"euler_a_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
if args.import_mlir:
_import(self)
else:
try:
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_a_scale_model_input_" + args.precision,
iree_flags,
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_a_step_"
+ self.config.prediction_type
+ args.precision,
iree_flags,
)
except Exception:
print(
"Failed to download the scheduler artifacts; falling back to import_mlir."
)
args.import_mlir = True
_import(self)
def scale_model_input(self, sample, timestep):
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
return self.scaling_model(
"forward",
(
sample,
sigma,
),
send_to_host=False,
)
def step(
self,
noise_pred,
timestep,
latent,
generator: Optional[torch.Generator] = None,
return_dict: Optional[bool] = False,
):
step_inputs = []
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sigma_from = self.sigmas[self.step_index]
sigma_to = self.sigmas[self.step_index + 1]
noise = randn_tensor(
torch.Size(noise_pred.shape),
dtype=torch.float16,
device="cpu",
generator=generator,
)
step_inputs = [
noise_pred,
latent,
sigma,
sigma_from,
sigma_to,
noise,
]
# TODO: deal with dynamic inputs in turbine flow.
# Advance the step index here; everything after this point runs in the compiled module.
self._step_index += 1
if noise_pred.shape[0] < self.batch_size:
for i in [0, 1, 5]:
try:
step_inputs[i] = torch.tensor(step_inputs[i])
except:
step_inputs[i] = torch.tensor(step_inputs[i].to_host())
step_inputs[i] = torch.cat(
(step_inputs[i], step_inputs[i]), axis=0
)
return self.step_model(
"forward",
tuple(step_inputs),
send_to_host=True,
)
return self.step_model(
"forward",
tuple(step_inputs),
send_to_host=False,
)
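Editor's note: a minimal sketch of how this compiled scheduler is driven in a denoising loop. It assumes the global SHARK args (height, width, precision, device) are already configured and that `unet` is a hypothetical callable accepting the device buffers that scale_model_input returns; only compile(), scale_model_input() and step() come from the class above.

import torch

scheduler = SharkEulerAncestralDiscreteScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)
scheduler.compile(batch_size=1)   # builds the SHARK scale/step modules
scheduler.set_timesteps(20)

latent = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(latent, t)  # stays on device
    noise_pred = unet(model_input, t)                     # hypothetical UNet call
    latent = scheduler.step(noise_pred, t, latent)        # compiled eps/v step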

View File

@@ -2,12 +2,9 @@ import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
@@ -27,6 +24,13 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
interpolation_type: str = "linear",
use_karras_sigmas: bool = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
timestep_spacing: str = "linspace",
timestep_type: str = "discrete",
steps_offset: int = 0,
):
super().__init__(
num_train_timesteps,
@@ -35,20 +39,29 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
beta_schedule,
trained_betas,
prediction_type,
interpolation_type,
use_karras_sigmas,
sigma_min,
sigma_max,
timestep_spacing,
timestep_type,
steps_offset,
)
# TODO: make it dynamic so we don't have to worry about batch size
self.batch_size = 1
def compile(self):
def compile(self, batch_size=1):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
device = args.device.split(":", 1)[0].strip()
self.batch_size = batch_size
model_input = {
"euler": {
"latent": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
batch_size, 4, args.height // 8, args.width // 8
),
"output": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
batch_size, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
@@ -70,12 +83,32 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepModel(torch.nn.Module):
class SchedulerStepEpsilonModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma_hat, latent, dt):
pred_original_sample = latent - sigma_hat * noise_pred
derivative = (latent - pred_original_sample) / sigma_hat
return latent + derivative * dt
class SchedulerStepSampleModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma_hat, latent, dt):
pred_original_sample = noise_pred
derivative = (latent - pred_original_sample) / sigma_hat
return latent + derivative * dt
class SchedulerStepVPredictionModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma, latent, dt):
pred_original_sample = latent - sigma * noise_pred
pred_original_sample = noise_pred * (
-sigma / (sigma**2 + 1) ** 0.5
) + (latent / (sigma**2 + 1))
derivative = (latent - pred_original_sample) / sigma
return latent + derivative * dt
@@ -90,16 +123,22 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
self.scaling_model, _ = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
extended_model_name=f"euler_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
pred_type_model_dict = {
"epsilon": SchedulerStepEpsilonModel(),
"v_prediction": SchedulerStepVPredictionModel(),
"sample": SchedulerStepSampleModel(),
"original_sample": SchedulerStepSampleModel(),
}
step_model = pred_type_model_dict[self.config.prediction_type]
self.step_model, _ = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
extended_model_name=f"euler_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
@@ -109,6 +148,11 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
else:
try:
step_model_type = (
"sample"
if "sample" in self.config.prediction_type
else self.config.prediction_type
)
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_scale_model_input_" + args.precision,
@@ -116,7 +160,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_step_" + args.precision,
"euler_step_" + step_model_type + args.precision,
iree_flags,
)
except:
@@ -127,8 +171,9 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
_import(self)
def scale_model_input(self, sample, timestep):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
return self.scaling_model(
"forward",
(
@@ -138,15 +183,61 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
send_to_host=False,
)
def step(self, noise_pred, timestep, latent):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
dt = self.sigmas[step_index + 1] - sigma
def step(
self,
noise_pred,
timestep,
latent,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: Optional[bool] = False,
):
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
gamma = (
min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
if s_tmin <= sigma <= s_tmax
else 0.0
)
sigma_hat = sigma * (gamma + 1)
noise_pred = (
torch.from_numpy(noise_pred)
if isinstance(noise_pred, np.ndarray)
else noise_pred
)
noise = randn_tensor(
torch.Size(noise_pred.shape),
dtype=torch.float16,
device="cpu",
generator=generator,
)
eps = noise * s_noise
if gamma > 0:
latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5
if self.config.prediction_type == "v_prediction":
sigma_hat = sigma
dt = self.sigmas[self.step_index + 1] - sigma_hat
self._step_index += 1
return self.step_model(
"forward",
(
noise_pred,
sigma,
sigma_hat,
latent,
dt,
),
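Editor's note: in the epsilon branch above, derivative reduces algebraically to noise_pred, so the whole step is latent + noise_pred * dt. A quick eager-mode check of that identity (no SHARK involved):

import torch

latent = torch.randn(1, 4, 64, 64)
noise_pred = torch.randn_like(latent)
sigma_hat, dt = torch.tensor(14.6), torch.tensor(-2.3)

pred_original = latent - sigma_hat * noise_pred
derivative = (latent - pred_original) / sigma_hat  # == noise_pred exactly
prev = latent + derivative * dt

assert torch.allclose(prev, latent + noise_pred * dt, atol=1e-5)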

View File

@@ -8,6 +8,15 @@
"dtype":"i64"
}
},
"sdxl_clip": {
"token" : {
"shape" : [
"1*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
@@ -179,9 +188,95 @@
"shape": [2],
"dtype": "i64"
}
},
"stabilityai/sdxl-turbo": {
"latents": {
"shape": [
"2*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"prompt_embeds": {
"shape": [
"2*batch_size",
"max_len",
2048
],
"dtype": "f32"
},
"text_embeds": {
"shape": [
"2*batch_size",
1280
],
"dtype": "f32"
},
"time_ids": {
"shape": [
"2*batch_size",
6
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 1,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-xl-base-1.0": {
"latents": {
"shape": [
"2*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"prompt_embeds": {
"shape": [
"2*batch_size",
"max_len",
2048
],
"dtype": "f32"
},
"text_embeds": {
"shape": [
"2*batch_size",
1280
],
"dtype": "f32"
},
"time_ids": {
"shape": [
"2*batch_size",
6
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 1,
"dtype": "f32"
}
}
},
"stencil_adaptor": {
"stencil_adapter": {
"latents": {
"shape": [
"1*batch_size",
@@ -208,6 +303,58 @@
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
},
"acc1": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc2": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc3": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc4": {
"shape": [2, 320, "height/2", "width/2"],
"dtype": "f32"
},
"acc5": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"acc6": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"acc7": {
"shape": [2, 640, "height/4", "width/4"],
"dtype": "f32"
},
"acc8": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"acc9": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"acc10": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc11": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc12": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
},
"stencil_unet": {
@@ -290,7 +437,59 @@
"control13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"scale1": {
"shape": 1,
"dtype": "f32"
},
"scale2": {
"shape": 1,
"dtype": "f32"
},
"scale3": {
"shape": 1,
"dtype": "f32"
},
"scale4": {
"shape": 1,
"dtype": "f32"
},
"scale5": {
"shape": 1,
"dtype": "f32"
},
"scale6": {
"shape": 1,
"dtype": "f32"
},
"scale7": {
"shape": 1,
"dtype": "f32"
},
"scale8": {
"shape": 1,
"dtype": "f32"
},
"scale9": {
"shape": 1,
"dtype": "f32"
},
"scale10": {
"shape": 1,
"dtype": "f32"
},
"scale11": {
"shape": 1,
"dtype": "f32"
},
"scale12": {
"shape": 1,
"dtype": "f32"
},
"scale13": {
"shape": 1,
"dtype": "f32"
}
}
}
}
}
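Editor's note: the string dimensions above ("2*batch_size", "height/8", ...) are symbolic; the latent-space entries are doubled because classifier-free guidance runs the conditional and unconditional passes in one batch. A hypothetical resolver showing how such a spec could be turned into concrete shapes (this is not the project's actual resolver):

def resolve_shape(shape, batch_size, height, width, max_len=77):
    # Evaluate symbolic dims like "2*batch_size" or "height/8" against
    # concrete values; ints pass through, scalar shapes become 1-D.
    env = {"batch_size": batch_size, "height": height,
           "width": width, "max_len": max_len}
    if isinstance(shape, int):
        return [shape]
    return [int(eval(d, {"__builtins__": {}}, env)) if isinstance(d, str) else d
            for d in shape]

resolve_shape(["2*batch_size", 4, "height", "width"], 1, 128, 128)
# -> [2, 4, 128, 128]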

View File

@@ -1,4 +1,5 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["Astrophotography, the shark nebula, nebula with a tiny shark-like cloud in the middle in the middle, hubble telescope, vivid colors"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],

View File

@@ -85,7 +85,7 @@ p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 769, 8),
choices=range(128, 1025, 8),
help="The height of the output image.",
)
@@ -93,7 +93,7 @@ p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 769, 8),
choices=range(128, 1025, 8),
help="The width of the output image.",
)
@@ -420,6 +420,13 @@ p.add_argument(
help="Enable the stencil feature.",
)
p.add_argument(
"--control_mode",
choices=["Prompt", "Balanced", "Controlnet"],
default="Balanced",
help="How Controlnet injection should be prioritized.",
)
p.add_argument(
"--use_lora",
type=str,
@@ -460,6 +467,13 @@ p.add_argument(
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
p.add_argument(
"--autogen",
action="store_true",
default=False,
help="Only used for a gradio workaround.",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################

View File

@@ -79,10 +79,12 @@ def controlnet_hint_shaping(
)
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width}),"
+ f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
+ f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
return controlnet_hint_shaping(
Image.fromarray(controlnet_hint.detach().numpy()),
height,
width,
dtype,
num_images_per_prompt,
)
elif isinstance(controlnet_hint, np.ndarray):
# np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_prompt)
@@ -109,29 +111,36 @@ def controlnet_hint_shaping(
) # b h w c -> b c h w
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({width}, {channels}), "
+ f"({height}, {width}, {channels}), "
+ f"(1, {height}, {width}, {channels}) or "
+ f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
)
elif isinstance(controlnet_hint, Image.Image):
if controlnet_hint.size == (width, height):
controlnet_hint = controlnet_hint.convert(
"RGB"
) # make sure 3 channel RGB format
controlnet_hint = np.array(controlnet_hint) # to numpy
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
return controlnet_hint_shaping(
controlnet_hint, height, width, num_images_per_prompt
Image.fromarray(controlnet_hint),
height,
width,
dtype,
num_images_per_prompt,
)
elif isinstance(controlnet_hint, Image.Image):
controlnet_hint = controlnet_hint.convert(
"RGB"
) # make sure 3 channel RGB format
if controlnet_hint.size == (width, height):
controlnet_hint = np.array(controlnet_hint).astype(
np.float16
) # to numpy
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
return controlnet_hint_shaping(
controlnet_hint, height, width, dtype, num_images_per_prompt
)
else:
raise ValueError(
f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}"
)
(hint_w, hint_h) = controlnet_hint.size
left = int((hint_w - width) / 2)
right = left + width
controlnet_hint = controlnet_hint.crop((left, 0, right, hint_h))
controlnet_hint = controlnet_hint.resize((width, height))
return controlnet_hint_shaping(
controlnet_hint, height, width, dtype, num_images_per_prompt
)
else:
raise ValueError(
f"Acceptable type of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
f"Acceptible controlnet input types are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
)
@@ -141,16 +150,22 @@ def controlnet_hint_conversion(
controlnet_hint = None
match use_stencil:
case "canny":
print("Detecting edge with canny")
print(
"Converting controlnet hint to edge detection mask with canny preprocessor."
)
controlnet_hint = hint_canny(image)
case "openpose":
print("Detecting human pose")
print(
"Detecting human pose in controlnet hint with openpose preprocessor."
)
controlnet_hint = hint_openpose(image)
case "scribble":
print("Working with scribble")
print("Using your scribble as a controlnet hint.")
controlnet_hint = hint_scribble(image)
case "zoedepth":
print("Working with ZoeDepth")
print(
"Converting controlnet hint to a depth mapping with ZoeDepth."
)
controlnet_hint = hint_zoedepth(image)
case _:
return None

View File

@@ -30,9 +30,15 @@ class ZoeDetector:
pretrained=False,
force_reload=False,
)
model.load_state_dict(
torch.load(modelpath, map_location=model.device)["model"]
)
# Hack for the ZoeDepth loading issue: drop checkpoint keys that the
# model architecture does not declare before calling load_state_dict.
model_keys = model.state_dict().keys()
loaded_dict = torch.load(modelpath, map_location=model.device)["model"]
loaded_keys = loaded_dict.keys()
for key in loaded_keys - model_keys:
loaded_dict.pop(key)
model.load_state_dict(loaded_dict)
model.eval()
self.model = model
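Editor's note: the key-filtering hack above is a general pattern worth naming. load_state_dict(strict=False) would also ignore the unexpected keys, but it silences missing-key errors too; the explicit filter keeps those. A reusable version of the same idea (illustrative, not part of this repo):

def load_filtered_state_dict(model, checkpoint_state):
    # Drop keys the architecture does not declare; unlike strict=False,
    # genuinely missing keys will still raise on load.
    wanted = set(model.state_dict().keys())
    filtered = {k: v for k, v in checkpoint_state.items() if k in wanted}
    model.load_state_dict(filtered)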

View File

@@ -565,9 +565,10 @@ def get_opt_flags(model, precision="fp16"):
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
if "vae" not in model:
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags

View File

@@ -19,6 +19,9 @@ a = Analysis(
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
module_collection_mode={
'gradio': 'py', # Collect gradio package as source .py files
},
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

View File

@@ -97,8 +97,6 @@ if __name__ == "__main__":
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_custom_model,
@@ -109,6 +107,16 @@ if __name__ == "__main__":
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
# SDXL
txt2img_sdxl_web,
txt2img_sdxl_custom_model,
txt2img_sdxl_gallery,
txt2img_sdxl_png_info_img,
txt2img_sdxl_status,
txt2img_sdxl_sendto_img2img,
txt2img_sdxl_sendto_inpaint,
txt2img_sdxl_sendto_outpaint,
txt2img_sdxl_sendto_upscaler,
# h2ogpt_upload,
# h2ogpt_web,
img2img_web,
@@ -145,7 +153,7 @@ if __name__ == "__main__":
upscaler_sendto_outpaint,
# lora_train_web,
# model_web,
# model_config_web,
model_config_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
@@ -159,6 +167,7 @@ if __name__ == "__main__":
outputgallery_watch,
outputgallery_filename,
outputgallery_sendto_txt2img,
outputgallery_sendto_txt2img_sdxl,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
@@ -172,7 +181,7 @@ if __name__ == "__main__":
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
gr.Tabs(selected=selectedid),
),
inputs,
outputs,
@@ -183,7 +192,7 @@ if __name__ == "__main__":
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
gr.Tabs(selected=selectedid),
),
inputs,
outputs,
@@ -193,12 +202,14 @@ if __name__ == "__main__":
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
gr.Tabs(selected=selectedid),
),
inputs,
outputs,
)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
) as sd_web:
@@ -235,6 +246,7 @@ if __name__ == "__main__":
inpaint_status,
outpaint_status,
upscaler_status,
txt2img_sdxl_status,
]
)
# with gr.TabItem(label="Model Manager", id=6):
@@ -243,16 +255,18 @@ if __name__ == "__main__":
# lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=8):
stablelm_chat.render()
# with gr.TabItem(
# label="Generate Sharding Config (Experimental)", id=9
# ):
# model_config_web.render()
with gr.TabItem(label="MultiModal (Experimental)", id=10):
minigpt4_web.render()
# with gr.TabItem(
# label="Generate Sharding Config (Experimental)", id=9
# ):
# model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
with gr.TabItem(label="Text-to-Image (SDXL)", id=13):
txt2img_sdxl_web.render()
actual_port = app.usable_port()
if actual_port != args.server_port:
@@ -391,6 +405,12 @@ if __name__ == "__main__":
[outputgallery_filename],
[upscaler_init_image, tabs],
)
register_outputgallery_button(
outputgallery_sendto_txt2img_sdxl,
0,
[outputgallery_filename],
[txt2img_sdxl_png_info_img, tabs],
)
register_modelmanager_button(
modelmanager_sendto_txt2img,
0,

View File

@@ -10,6 +10,18 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.txt2img_sdxl_ui import (
txt2img_sdxl_inf,
txt2img_sdxl_web,
txt2img_sdxl_custom_model,
txt2img_sdxl_gallery,
txt2img_sdxl_status,
txt2img_sdxl_png_info_img,
txt2img_sdxl_sendto_img2img,
txt2img_sdxl_sendto_inpaint,
txt2img_sdxl_sendto_outpaint,
txt2img_sdxl_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
img2img_web,
@@ -76,6 +88,7 @@ from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_watch,
outputgallery_filename,
outputgallery_sendto_txt2img,
outputgallery_sendto_txt2img_sdxl,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,

View File

@@ -129,10 +129,6 @@ body {
padding: 0 var(--size-4) !important;
}
#ui_title {
padding: var(--size-2) 0 0 var(--size-1);
}
#top_logo {
color: transparent;
background-color: transparent;
@@ -140,6 +136,10 @@ body {
border: 0;
}
#ui_title {
padding: var(--size-2) 0 0 var(--size-1);
}
#demo_title_outer {
border-radius: 0;
}
@@ -234,11 +234,6 @@ footer {
display:none;
}
/* Hide the download icon from the nod logo */
#top_logo button {
display: none;
}
/* workarounds for container=false not currently working for dropdowns */
.dropdown_no_container {
padding: 0 !important;

View File

@@ -6,6 +6,12 @@ import PIL
from math import ceil
from PIL import Image
from gradio.components.image_editor import (
Brush,
Eraser,
EditorData,
EditorValue,
)
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -31,6 +37,11 @@ from apps.stable_diffusion.src.utils import (
get_generation_text_info,
resampler_list,
)
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
ZoeDetector,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
import numpy as np
@@ -60,7 +71,6 @@ def img2img_inf(
precision: str,
device: str,
max_length: int,
use_stencil: str,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
@@ -68,6 +78,9 @@ def img2img_inf(
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
control_mode: str,
stencils: list,
images: list,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -89,14 +102,23 @@ def img2img_inf(
args.img_path = "not none"
args.ondemand = ondemand
if image_dict is None:
for i, stencil in enumerate(stencils):
if images[i] is None and stencil is not None:
return
if images[i] is not None:
if isinstance(images[i], dict):
images[i] = images[i]["composite"]
images[i] = images[i].convert("RGB")
if image_dict is None and images[0] is None:
return None, "An Initial Image is required"
if use_stencil == "scribble":
image = image_dict["mask"].convert("RGB")
elif isinstance(image_dict, PIL.Image.Image):
if isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
else:
elif image_dict:
image = image_dict["image"].convert("RGB")
else:
# TODO: enable t2i + controlnets
image = None
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
@@ -123,10 +145,11 @@ def img2img_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_stencil = None if use_stencil == "None" else use_stencil
args.use_stencil = use_stencil
if use_stencil is not None:
args.scheduler = "DDIM"
stencil_count = 0
for stencil in stencils:
if stencil is not None:
stencil_count += 1
if stencil_count > 0:
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
@@ -150,7 +173,7 @@ def img2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=use_stencil,
stencils=stencils,
ondemand=ondemand,
)
if (
@@ -172,12 +195,12 @@ def img2img_inf(
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
else "stabilityai/stable-diffusion-1-5-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
if use_stencil is not None:
if stencil_count > 0:
args.use_tuned = False
global_obj.set_sd_obj(
StencilPipeline.from_pretrained(
@@ -194,7 +217,7 @@ def img2img_inf(
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
stencils=stencils,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
@@ -251,8 +274,10 @@ def img2img_inf(
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
stencils,
images,
resample_type=resample_type,
control_mode=control_mode,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
@@ -272,12 +297,17 @@ def img2img_inf(
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Image-to-Image", current_batch + 1, batch_count, batch_size
)
), stencils, images
return generated_imgs, text_output, ""
return generated_imgs, text_output, "", stencils, images
with gr.Blocks(title="Image-to-Image") as img2img_web:
# Stencils
# TODO: Add more stencils here
STENCIL_COUNT = 2
stencils = gr.State([None] * STENCIL_COUNT)
images = gr.State([None] * STENCIL_COUNT)
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
@@ -289,6 +319,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
elem_id="top_logo",
width=150,
height=50,
show_download_button=False,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
@@ -342,75 +373,282 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
# TODO: make this import image prompt info if it exists
img2img_init_image = gr.Image(
label="Input Image",
source="upload",
tool="sketch",
type="pil",
height=300,
height=512,
interactive=True,
)
with gr.Accordion(label="Stencil Options", open=False):
with gr.Accordion(label="Multistencil Options", open=False):
choices = [
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
]
def cnet_preview(
model, input_image, index, stencils, images
):
images[index] = input_image
stencils[index] = model
match model:
case "canny":
canny = CannyDetector()
result = canny(
np.array(input_image["composite"]),
100,
200,
)
return (
Image.fromarray(result),
stencils,
images,
)
case "openpose":
openpose = OpenposeDetector()
result = openpose(
np.array(input_image["composite"])
)
print(result)
# TODO: This is just an empty canvas, need to draw the candidates (which are in result[1])
return (
Image.fromarray(result[0]),
stencils,
images,
)
case "zoedepth":
zoedepth = ZoeDetector()
result = zoedepth(
np.array(input_image["composite"])
)
return (
Image.fromarray(result),
stencils,
images,
)
case "scribble":
return (
input_image["composite"],
stencils,
images,
)
case _:
return (None, stencils, images)
def create_canvas(width, height):
data = Image.fromarray(
np.zeros(
shape=(height, width, 3),
dtype=np.uint8,
)
+ 255
)
img_dict = {
"background": data,
"layers": [data],
"composite": None,
}
return EditorValue(img_dict)
def update_cn_input(model, width, height):
if model == "scribble":
return [
gr.ImageEditor(
visible=True,
interactive=True,
show_label=False,
image_mode="RGB",
type="pil",
value=create_canvas(width, height),
brush=Brush(
colors=["#000000"], color_mode="fixed"
),
),
gr.Image(
visible=True,
show_label=False,
interactive=False,
show_download_button=False,
),
gr.Slider(visible=True),
gr.Slider(visible=True),
gr.Button(visible=True),
]
else:
return [
gr.ImageEditor(
visible=True,
image_mode="RGB",
type="pil",
interactive=True,
value=None,
),
gr.Image(
visible=True,
show_label=False,
interactive=True,
show_download_button=False,
),
gr.Slider(visible=False),
gr.Slider(visible=False),
gr.Button(visible=False),
]
with gr.Row():
use_stencil = gr.Dropdown(
elem_id="stencil_model",
label="Stencil model",
value="None",
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
with gr.Column():
cnet_1 = gr.Button(
value="Generate controlnet input"
)
cnet_1_model = gr.Dropdown(
label="Controlnet 1",
value="None",
choices=choices,
)
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
make_canvas = gr.Button(
value="Make Canvas!",
visible=False,
)
cnet_1_image = gr.ImageEditor(
visible=False,
image_mode="RGB",
interactive=True,
show_label=False,
type="pil",
)
cnet_1_output = gr.Image(
visible=True, show_label=False
)
cnet_1_model.input(
update_cn_input,
[cnet_1_model, canvas_width, canvas_height],
[
cnet_1_image,
cnet_1_output,
canvas_width,
canvas_height,
make_canvas,
],
)
def show_canvas(choice):
if choice == "scribble":
return (
gr.Slider.update(visible=True),
gr.Slider.update(visible=True),
gr.Button.update(visible=True),
)
else:
return (
gr.Slider.update(visible=False),
gr.Slider.update(visible=False),
gr.Button.update(visible=False),
)
def create_canvas(w, h):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
make_canvas.click(
update_cn_input,
[cnet_1_model, canvas_width, canvas_height],
[
cnet_1_image,
cnet_1_output,
canvas_width,
canvas_height,
make_canvas,
],
)
cnet_1.click(
fn=(
lambda a, b, s, i: cnet_preview(a, b, 0, s, i)
),
inputs=[
cnet_1_model,
cnet_1_image,
stencils,
images,
],
outputs=[cnet_1_output, stencils, images],
)
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
with gr.Column():
cnet_2 = gr.Button(
value="Generate controlnet input"
)
cnet_2_model = gr.Dropdown(
label="Controlnet 2",
value="None",
choices=choices,
)
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
make_canvas = gr.Button(
value="Make Canvas!",
visible=False,
)
cnet_2_image = gr.ImageEditor(
visible=False,
image_mode="RGB",
interactive=True,
show_label=False,
type="pil",
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
cnet_2_output = gr.Image(
visible=True, show_label=False
)
create_button = gr.Button(
label="Start",
value="Open drawing canvas!",
visible=False,
)
create_button.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[img2img_init_image],
)
use_stencil.change(
fn=show_canvas,
inputs=use_stencil,
outputs=[canvas_width, canvas_height, create_button],
cnet_2_model.select(
update_cn_input,
[cnet_2_model, canvas_width, canvas_height],
[
cnet_2_image,
cnet_2_output,
canvas_width,
canvas_height,
make_canvas,
],
)
make_canvas.click(
update_cn_input,
[cnet_2_model, canvas_width, canvas_height],
[
cnet_2_image,
cnet_2_output,
canvas_width,
canvas_height,
make_canvas,
],
)
cnet_2.click(
fn=(
lambda a, b, s, i: cnet_preview(a, b, 1, s, i)
),
inputs=[
cnet_2_model,
cnet_2_image,
stencils,
images,
],
outputs=[cnet_2_output, stencils, images],
)
control_mode = gr.Radio(
choices=["Prompt", "Balanced", "Controlnet"],
value="Balanced",
label="Control Mode",
)
with gr.Accordion(label="LoRA Options", open=False):
@@ -617,7 +855,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
precision,
device,
max_length,
use_stencil,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
@@ -625,8 +862,17 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
ondemand,
repeatable_seeds,
resample_type,
control_mode,
stencils,
images,
],
outputs=[
img2img_gallery,
std_output,
img2img_status,
stencils,
images,
],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress="minimal" if args.progress_bar else "none",
)
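Editor's note: both controlnet columns above route their click handlers through cnet_preview with a hard-coded slot index (the `lambda a, b, s, i: cnet_preview(a, b, 0, s, i)` wrappers). The same wiring written as a factory, to make the pattern explicit if more slots are added later (illustrative only):

def make_preview_handler(index, preview_fn):
    # Bind the stencil slot index so one dispatcher serves every column.
    return lambda model, image, stencils, images: preview_fn(
        model, image, index, stencils, images
    )

# cnet_1.click(fn=make_preview_handler(0, cnet_preview), ...)
# cnet_2.click(fn=make_preview_handler(1, cnet_preview), ...)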

View File

@@ -122,7 +122,7 @@ def inpaint_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -234,6 +234,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
elem_id="top_logo",
width=150,
height=50,
show_download_button=False,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
@@ -290,8 +291,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
inpaint_init_image = gr.Image(
label="Masked Image",
source="upload",
tool="sketch",
sources="upload",
type="pil",
height=350,
)

View File

@@ -26,6 +26,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
elem_id="top_logo",
width=150,
height=50,
show_download_button=False,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():

View File

@@ -104,7 +104,6 @@ with gr.Blocks() as model_web:
civit_models = gr.Gallery(
label="Civitai Model Gallery",
value=None,
interactive=True,
visible=False,
)

View File

@@ -121,7 +121,7 @@ def outpaint_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -239,6 +239,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
elem_id="top_logo",
width=150,
height=50,
show_download_button=False,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():

View File

@@ -84,6 +84,7 @@ with gr.Blocks() as outputgallery_web:
show_label=True,
elem_id="top_logo",
elem_classes="logo_centered",
show_download_button=False,
)
gallery = gr.Gallery(
@@ -95,7 +96,7 @@ with gr.Blocks() as outputgallery_web:
)
with gr.Column(scale=4):
with gr.Box():
with gr.Group():
with gr.Row():
with gr.Column(
scale=15,
@@ -152,6 +153,7 @@ with gr.Blocks() as outputgallery_web:
wrap=True,
elem_classes="output_parameters_dataframe",
value=[["Status", "No image selected"]],
interactive=True,
)
with gr.Accordion(label="Send To", open=True):
@@ -162,6 +164,12 @@ with gr.Blocks() as outputgallery_web:
elem_classes="outputgallery_sendto",
size="sm",
)
outputgallery_sendto_txt2img_sdxl = gr.Button(
value="Txt2Img XL",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)
outputgallery_sendto_img2img = gr.Button(
value="Img2Img",
@@ -195,17 +203,17 @@ with gr.Blocks() as outputgallery_web:
def on_clear_gallery():
return [
gr.Gallery.update(
gr.Gallery(
value=[],
visible=False,
),
gr.Image.update(
gr.Image(
visible=True,
),
]
def on_image_columns_change(columns):
return gr.Gallery.update(columns=columns)
return gr.Gallery(columns=columns)
def on_select_subdir(subdir) -> list:
# evt.value is the subdirectory name
@@ -215,12 +223,12 @@ with gr.Blocks() as outputgallery_web:
)
return [
new_images,
gr.Gallery.update(
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image.update(
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
@@ -254,16 +262,16 @@ with gr.Blocks() as outputgallery_web:
)
return [
gr.Dropdown.update(
gr.Dropdown(
choices=refreshed_subdirs,
value=new_subdir,
),
refreshed_subdirs,
new_images,
gr.Gallery.update(
gr.Gallery(
value=new_images, label=new_label, visible=len(new_images) > 0
),
gr.Image.update(
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
@@ -289,12 +297,12 @@ with gr.Blocks() as outputgallery_web:
return [
new_images,
gr.Gallery.update(
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image.update(
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
@@ -332,12 +340,12 @@ with gr.Blocks() as outputgallery_web:
return [
# disable or enable each of the sendto button based on whether
# an image is selected
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
]
# The time first our tab is selected we need to do an initial refresh
@@ -414,6 +422,7 @@ with gr.Blocks() as outputgallery_web:
[outputgallery_filename],
[
outputgallery_sendto_txt2img,
outputgallery_sendto_txt2img_sdxl,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,

View File

@@ -141,7 +141,9 @@ def chat(
prompt_prefix,
history,
model,
device,
backend,
devices,
sharded,
precision,
download_vmfb,
config_file,
@@ -153,7 +155,8 @@ def chat(
global vicuna_model
model_name, model_path = list(map(str.strip, model.split("=>")))
device, device_id = clean_device_info(device)
device, device_id = clean_device_info(devices[0])
no_of_devices = len(devices)
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
@@ -173,6 +176,13 @@ def chat(
get_vulkan_target_triple,
)
_extra_args = _extra_args + [
"--iree-global-opt-enable-quantized-matmul-reassociation",
"--iree-llvmcpu-enable-quantized-matmul-reassociation",
"--iree-opt-const-eval=false",
"--iree-opt-data-tiling=false",
]
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
@@ -214,7 +224,7 @@ def chat(
)
print(f"extra args = {_extra_args}")
if model_name == "vicuna4":
if sharded:
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
@@ -223,6 +233,7 @@ def chat(
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
n_devices=no_of_devices,
)
else:
# if config_file is None:
@@ -250,13 +261,14 @@ def chat(
total_time_ms = 0.001  # avoid a divide-by-zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
# for text, msg, exec_time in progress.tqdm(
# vicuna_model.generate(prompt, cli=cli),
# desc="generating response",
# ):
for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli):
if msg is None:
if is_first:
prefill_time = exec_time
prefill_time = exec_time / 1000
is_first = False
else:
total_time_ms += exec_time
@@ -377,6 +389,16 @@ def view_json_file(file_obj):
return content
filtered_devices = dict()
def change_backend(backend):
new_choices = gr.Dropdown(
choices=filtered_devices[backend], label=f"{backend} devices"
)
return new_choices
with gr.Blocks(title="Chatbot") as stablelm_chat:
with gr.Row():
model_choices = list(
@@ -393,15 +415,22 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
backend_list = ["cpu", "cuda", "vulkan", "rocm"]
for x in backend_list:
filtered_devices[x] = [y for y in supported_devices if x in y]
print(filtered_devices)
backend = gr.Radio(
label="backend",
value="cpu",
choices=backend_list,
)
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
if enabled
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
label="cpu devices",
choices=filtered_devices["cpu"],
interactive=True,
allow_custom_value=True,
# multiselect=True,
multiselect=True,
)
precision = gr.Radio(
label="Precision",
@@ -425,14 +454,19 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
value=False,
interactive=True,
)
sharded = gr.Checkbox(
label="Shard Model",
value=False,
interactive=True,
)
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button = gr.Button(value="View as JSON", visible=False)
json_view = gr.JSON(visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
@@ -452,6 +486,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
backend.change(
fn=change_backend,
inputs=[backend],
outputs=[device],
show_progress=False,
)
submit_event = msg.submit(
fn=user,
inputs=[msg, chatbot],
@@ -464,7 +505,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
prompt_prefix,
chatbot,
model,
backend,
device,
sharded,
precision,
download_vmfb,
config_file,
@@ -485,7 +528,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
prompt_prefix,
chatbot,
model,
backend,
device,
sharded,
precision,
download_vmfb,
config_file,
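Editor's note: the backend radio above relies on a plain substring filter over the device strings; a minimal reproduction with a hypothetical device list, to show what lands in each bucket:

supported_devices = [
    "cpu-task",
    "vulkan://0 => AMD Radeon",
    "cuda://0 => NVIDIA A100",
]
backend_list = ["cpu", "cuda", "vulkan", "rocm"]
filtered = {b: [d for d in supported_devices if b in d] for b in backend_list}
# filtered["vulkan"] -> ["vulkan://0 => AMD Radeon"]; filtered["rocm"] -> []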

View File

@@ -0,0 +1,650 @@
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
from math import ceil
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
predefined_sdxl_models,
cancel_sd,
set_model_default_configs,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
save_output_img,
prompt_examples,
Image2ImagePipeline,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platform = args.iree_metal_target_platform
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
def txt2img_sdxl_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
if precision != "fp16":
print("currently we support fp16 for SDXL")
precision = "fp16"
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files():
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = model_id
if custom_vae:
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"txt2img_sdxl",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
stencils=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-xl-base-1.0"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
if global_obj.get_cfg_obj().ondemand:
print("Running txt2img in memory efficient mode.")
global_obj.set_sd_obj(
Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=precision,
max_length=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=global_obj.get_cfg_obj().ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], seeds[current_batch])
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Text-to-Image-SDXL",
current_batch + 1,
batch_count,
batch_size,
)
return generated_imgs, text_output, ""
theme = gr.themes.Glass(
primary_hue="slate",
secondary_hue="gray",
)
with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
width=150,
height=50,
show_download_button=False,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
t2i_sdxl_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
txt2img_sdxl_custom_model = gr.Dropdown(
label=f"Models",
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-xl-base-1.0",
choices=predefined_sdxl_models
+ get_custom_model_files(
custom_checkpoint_type="sdxl"
),
allow_custom_value=True,
scale=2,
)
t2i_sdxl_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
t2i_sdxl_vae_info = (
f"VAE Path: {t2i_sdxl_vae_info}"
)
custom_vae = gr.Dropdown(
label=f"VAE Models",
info=t2i_sdxl_vae_info,
elem_id="custom_model",
value="None",
choices=[
None,
"madebyollin/sdxl-vae-fp16-fix",
]
+ get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Column(scale=1, min_width=170):
txt2img_sdxl_png_info_img = gr.Image(
label="Import PNG info",
elem_id="txt2img_prompt_image",
type="pil",
visible=True,
)
with gr.Group(elem_id="prompt_box_outer"):
txt2img_sdxl_autogen = gr.Checkbox(
label="Auto-Generate Images",
value=False,
visible=False,
)
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=2,
elem_id="prompt_box",
show_copy_button=True,
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=2,
elem_id="negative_prompt_box",
show_copy_button=True,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
t2i_sdxl_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
t2i_sdxl_lora_info = f"LoRA Path: {t2i_sdxl_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=t2i_sdxl_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
choices=[
"DDIM",
"EulerAncestralDiscrete",
"EulerDiscrete",
"LCMScheduler",
],
allow_custom_value=False,
visible=True,
)
with gr.Column():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
512,
1024,
value=1024,
step=256,
label="Height",
visible=True,
interactive=True,
)
width = gr.Slider(
512,
1024,
value=1024,
step=256,
label="Width",
visible=True,
interactive=True,
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"fp16",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=77,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="Guidance Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=False,
visible=False,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
txt2img_sdxl_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
columns=[2],
object_fit="scale_down",
)
std_output = gr.Textbox(
value=f"{t2i_sdxl_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
txt2img_sdxl_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
txt2img_sdxl_sendto_img2img = gr.Button(
value="Send To Img2Img",
visible=False,
)
txt2img_sdxl_sendto_inpaint = gr.Button(
value="Send To Inpaint",
visible=False,
)
txt2img_sdxl_sendto_outpaint = gr.Button(
value="Send To Outpaint",
visible=False,
)
txt2img_sdxl_sendto_upscaler = gr.Button(
value="Send To Upscaler",
visible=False,
)
kwargs = dict(
fn=txt2img_sdxl_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
txt2img_sdxl_custom_model,
custom_vae,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
repeatable_seeds,
],
outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status],
show_progress="minimal" if args.progress_bar else "none",
queue=True,
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=txt2img_sdxl_status,
concurrency_limit=1,
)
def autogen_changed(checked):
args.autogen = checked
def check_last_input(prompt):
if not prompt.endswith(" "):
return True
elif not args.autogen:
return True
else:
return False
auto_gen_kwargs = dict(
fn=check_last_input,
inputs=[negative_prompt],
outputs=[txt2img_sdxl_status],
concurrency_limit=1,
)
txt2img_sdxl_autogen.change(
fn=autogen_changed,
inputs=[txt2img_sdxl_autogen],
outputs=None,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[
prompt_submit,
neg_prompt_submit,
generate_click,
],
)
txt2img_sdxl_png_info_img.change(
fn=import_png_metadata,
inputs=[
txt2img_sdxl_png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
txt2img_sdxl_custom_model,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
txt2img_sdxl_png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
txt2img_sdxl_custom_model,
lora_weights,
lora_hf_id,
custom_vae,
],
)
txt2img_sdxl_custom_model.change(
fn=set_model_default_configs,
inputs=[
txt2img_sdxl_custom_model,
],
outputs=[
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
width,
height,
custom_vae,
txt2img_sdxl_autogen,
],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)
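Editor's note: the Config comparison in txt2img_sdxl_inf above is a cache-key pattern: the pipeline is rebuilt only when a compile-relevant knob changes. Stripped to its essentials (the frozen dataclass is an assumption about how such a key could be modeled, not the repo's actual Config helper):

from dataclasses import dataclass

@dataclass(frozen=True)
class PipelineKey:
    model_id: str
    precision: str
    batch_size: int
    height: int
    width: int
    device: str

_cached_key, _cached_pipe = None, None

def get_pipeline(key, build):
    global _cached_key, _cached_pipe
    if key != _cached_key:        # any knob change invalidates the cache
        _cached_pipe = build(key)
        _cached_key = key
    return _cached_pipe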

View File

@@ -1,4 +1,6 @@
import json
import os
import warnings
import torch
import time
import sys
@@ -35,6 +37,34 @@ from apps.stable_diffusion.src.utils import (
resampler_list,
)
# Names of all interactive fields that can be edited by the user
all_gradio_labels = [
"txt2img_custom_model",
"custom_vae",
"prompt",
"negative_prompt",
"lora_weights",
"lora_hf_id",
"scheduler",
"save_metadata_to_png",
"save_metadata_to_json",
"height",
"width",
"steps",
"guidance_scale",
"Low VRAM",
"use_hiresfix",
"resample_type",
"hiresfix_height",
"hiresfix_width",
"hiresfix_strength",
"batch_count",
"batch_size",
"repeatable_seeds",
"seed",
"device",
]
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platform = args.iree_metal_target_platform
@@ -126,7 +156,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -226,7 +256,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil="None",
stencils=[],
ondemand=ondemand,
)
@@ -280,7 +310,8 @@ def txt2img_inf(
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil="None",
stencils=[],
control_mode=None,
resample_type=resample_type,
)
total_time = time.time() - start_time
@@ -302,7 +333,92 @@ def txt2img_inf(
return generated_imgs, text_output, ""
with gr.Blocks(title="Text-to-Image") as txt2img_web:
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
# Export the current values of all user-editable fields to the settings.json file in the ui folder
def export_settings(*values):
settings = dict(zip(all_gradio_labels, values))
settings = {"txt2img": settings}
with open("./ui/settings.json", "w") as json_file:
json.dump(settings, json_file, indent=4)
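# A minimal sketch of the resulting file, assuming hypothetical values for
# the first three labels (remaining entries elided):
#
#   export_settings("model.safetensors", "None", "a photo of a shark", ...)
#
#   ./ui/settings.json:
#   {
#       "txt2img": {
#           "txt2img_custom_model": "model.safetensors",
#           "custom_vae": "None",
#           "prompt": "a photo of a shark",
#           ...
#       }
#   }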
# Load saved values for all user-editable fields from the settings.json file in the ui folder
def load_settings():
try:
with open("./ui/settings.json", "r") as json_file:
loaded_settings = json.load(json_file)["txt2img"]
except (FileNotFoundError, KeyError):
warnings.warn(
"settings.json not found or missing the 'txt2img' key; using default values for all fields."
)
loaded_settings = {}  # settings file missing or not yet saved
return [
loaded_settings.get(
"txt2img_custom_model",
os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
),
loaded_settings.get(
"custom_vae",
os.path.basename(args.custom_vae) if args.custom_vae else "None",
),
loaded_settings.get("prompt", args.prompts[0]),
loaded_settings.get("negative_prompt", args.negative_prompts[0]),
loaded_settings.get("lora_weights", "None"),
loaded_settings.get("lora_hf_id", ""),
loaded_settings.get("scheduler", args.scheduler),
loaded_settings.get(
"save_metadata_to_png", args.write_metadata_to_png
),
loaded_settings.get(
"save_metadata_to_json", args.save_metadata_to_json
),
loaded_settings.get("height", args.height),
loaded_settings.get("width", args.width),
loaded_settings.get("steps", args.steps),
loaded_settings.get("guidance_scale", args.guidance_scale),
loaded_settings.get("Low VRAM", args.ondemand),
loaded_settings.get("use_hiresfix", args.use_hiresfix),
loaded_settings.get("resample_type", args.resample_type),
loaded_settings.get("hiresfix_height", args.hiresfix_height),
loaded_settings.get("hiresfix_width", args.hiresfix_width),
loaded_settings.get("hiresfix_strength", args.hiresfix_strength),
loaded_settings.get("batch_count", args.batch_count),
loaded_settings.get("batch_size", args.batch_size),
loaded_settings.get("repeatable_seeds", args.repeatable_seeds),
loaded_settings.get("seed", args.seed),
loaded_settings.get("device", available_devices[0]),
]
# Load the user's exported default settings at program start.
def onload_load_settings():
loaded_data = load_settings()
structured_data = list(zip(all_gradio_labels, loaded_data))
return dict(structured_data)
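# Example (hypothetical saved file): after a user exports steps=30,
# default_settings.get("steps") yields 30 on the next launch; with no
# settings.json present, it falls back to args.steps via load_settings().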
default_settings = onload_load_settings()
with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
@@ -314,6 +430,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
elem_id="top_logo",
width=150,
height=50,
show_download_button=False,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
@@ -326,9 +443,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label=f"Models",
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
value=default_settings.get(
"txt2img_custom_model"
),
choices=get_custom_model_files()
+ predefined_models,
allow_custom_value=True,
@@ -343,9 +460,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label=f"VAE Models",
info=t2i_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
value=default_settings.get("custom_vae"),
choices=["None"]
+ get_custom_model_files("vae"),
allow_custom_value=True,
@@ -356,20 +471,24 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Import PNG info",
elem_id="txt2img_prompt_image",
type="pil",
tool="None",
visible=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
value=default_settings.get("prompt"),
lines=2,
elem_id="prompt_box",
)
# TODO: coming soon
autogen = gr.Checkbox(
label="Continuous Generation",
visible=False,
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
value=default_settings.get("negative_prompt"),
lines=2,
elem_id="negative_prompt_box",
)
@@ -384,7 +503,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label=f"Standalone LoRA Weights",
info=t2i_lora_info,
elem_id="lora_weights",
value="None",
value=default_settings.get("lora_weights"),
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
@@ -394,7 +513,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
value=default_settings.get("lora_hf_id"),
label="HuggingFace Model ID",
lines=3,
)
@@ -408,33 +527,37 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
value=default_settings.get("scheduler"),
choices=scheduler_list,
allow_custom_value=True,
)
with gr.Column():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
value=default_settings.get(
"save_metadata_to_png"
),
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
value=default_settings.get(
"save_metadata_to_json"
),
interactive=True,
)
with gr.Row():
height = gr.Slider(
384,
768,
value=args.height,
value=default_settings.get("height"),
step=8,
label="Height",
)
width = gr.Slider(
384,
768,
value=args.width,
value=default_settings.get("width"),
step=8,
label="Width",
)
@@ -459,18 +582,22 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row():
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
1,
100,
value=default_settings.get("steps"),
step=1,
label="Steps",
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
value=default_settings.get("guidance_scale"),
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
value=default_settings.get("Low VRAM"),
label="Low VRAM",
interactive=True,
)
@@ -479,7 +606,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
value=default_settings.get("batch_count"),
step=1,
label="Batch Count",
interactive=True,
@@ -490,23 +617,23 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
4,
value=args.batch_size,
value=default_settings.get("batch_size"),
step=1,
label="Batch Size",
interactive=True,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
default_settings.get("repeatable_seeds"),
label="Repeatable Seeds",
)
with gr.Accordion(label="Hires Fix Options", open=False):
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
value=default_settings.get("use_hiresfix"),
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
value=default_settings.get("resample_type"),
choices=resampler_list,
label="Resample Type",
allow_custom_value=False,
@@ -514,34 +641,34 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
value=default_settings.get("hiresfix_height"),
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
value=default_settings.get("hiresfix_width"),
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
value=default_settings.get("hiresfix_strength"),
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
seed = gr.Textbox(
value=args.seed,
value=default_settings.get("seed"),
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
value=default_settings.get("device"),
choices=available_devices,
allow_custom_value=True,
)
@@ -592,6 +719,75 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
txt2img_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
with gr.Row():
with gr.Column(scale=2):
export_defaults = gr.Button(
value="Load Default Settings"
)
export_defaults.click(
fn=load_settings,
inputs=[],
outputs=[
txt2img_custom_model,
custom_vae,
prompt,
negative_prompt,
lora_weights,
lora_hf_id,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
height,
width,
steps,
guidance_scale,
ondemand,
use_hiresfix,
resample_type,
hiresfix_height,
hiresfix_width,
hiresfix_strength,
batch_count,
batch_size,
repeatable_seeds,
seed,
device,
],
)
with gr.Column(scale=2):
export_defaults = gr.Button(
value="Export Default Settings"
)
export_defaults.click(
fn=export_settings,
inputs=[
txt2img_custom_model,
custom_vae,
prompt,
negative_prompt,
lora_weights,
lora_hf_id,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
height,
width,
steps,
guidance_scale,
ondemand,
use_hiresfix,
resample_type,
hiresfix_height,
hiresfix_width,
hiresfix_strength,
batch_count,
batch_size,
repeatable_seeds,
seed,
device,
],
outputs=[],
)
kwargs = dict(
fn=txt2img_inf,
@@ -680,12 +876,12 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
# SharkEulerDiscrete doesn't work with img2img, which hires fix uses
def set_compatible_schedulers(hires_fix_selected):
if hires_fix_selected:
return gr.Dropdown.update(
return gr.Dropdown(
choices=scheduler_list_cpu_only,
value="DEISMultistep",
)
else:
return gr.Dropdown.update(
return gr.Dropdown(
choices=scheduler_list,
value="SharkEulerDiscrete",
)
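# A minimal wiring sketch, assuming this is hooked to the hires-fix checkbox
# (the actual .change() registration may live elsewhere in this file):
#
#   use_hiresfix.change(
#       fn=set_compatible_schedulers,
#       inputs=[use_hiresfix],
#       outputs=[scheduler],
#   )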

View File

@@ -120,7 +120,7 @@ def upscaler_inf(
args.width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -258,6 +258,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
elem_id="top_logo",
width=150,
height=50,
show_download_button=False,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():

View File

@@ -4,6 +4,7 @@ import glob
import math
import json
import safetensors
import gradio as gr
from pathlib import Path
from apps.stable_diffusion.src import args
@@ -30,7 +31,7 @@ class Config:
width: int
device: str
use_lora: str
use_stencil: str
stencils: list[str]
ondemand: bool  # receives the Low VRAM checkbox value, a boolean
@@ -64,9 +65,11 @@ scheduler_list_cpu_only = [
"DPMSolverSinglestep",
"DDPM",
"HeunDiscrete",
"LCMScheduler",
]
scheduler_list = scheduler_list_cpu_only + [
"SharkEulerDiscrete",
"SharkEulerAncestralDiscrete",
]
predefined_models = [
@@ -87,6 +90,10 @@ predefined_paint_models = [
predefined_upscaler_models = [
"stabilityai/stable-diffusion-x4-upscaler",
]
predefined_sdxl_models = [
"stabilityai/sdxl-turbo",
"stabilityai/stable-diffusion-xl-base-1.0",
]
def resource_path(relative_path):
@@ -140,6 +147,12 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
)
]
match custom_checkpoint_type:
case "sdxl":
files = [
val
for val in files
if any(x in val for x in ["XL", "xl", "Xl"])
]
case "inpainting":
files = [
val
@@ -247,6 +260,99 @@ def cancel_sd():
pass
def set_model_default_configs(model_ckpt_or_id, jsonconfig=None):
config_modelname = default_config_exists(model_ckpt_or_id)
if jsonconfig:
return get_config_from_json(jsonconfig)
elif config_modelname:
return default_configs[config_modelname]
# TODO: Use HF metadata to setup pipeline if available
# elif is_valid_hf_id(model_ckpt_or_id):
# return get_HF_default_configs(model_ckpt_or_id)
else:
# We don't have default metadata to set up a good config. Do not change configs.
return [
gr.Textbox(label="Prompt", interactive=True, visible=True),
gr.Textbox(label="Negative Prompt", interactive=True),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.Checkbox(
label="Auto-Generate",
visible=False,
interactive=False,
value=False,
),
]
def get_config_from_json(jsonconfig):
# TODO: make this work properly. It is currently not user-exposed.
cfgdata = json.load(jsonconfig)
return [
cfgdata["prompt_box_behavior"],
cfgdata["neg_prompt_box_behavior"],
cfgdata["steps"],
cfgdata["scheduler"],
cfgdata["guidance_scale"],
cfgdata["width"],
cfgdata["height"],
cfgdata["custom_vae"],
]
def default_config_exists(model_ckpt_or_id):
if model_ckpt_or_id in [
"stabilityai/sdxl-turbo",
"stabilityai/stable_diffusion-xl-base-1.0",
]:
return model_ckpt_or_id
elif "turbo" in model_ckpt_or_id.lower():
return "stabilityai/sdxl-turbo"
else:
return None
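# Illustrative lookups (hypothetical checkpoint names):
#   default_config_exists("stabilityai/sdxl-turbo")        -> "stabilityai/sdxl-turbo"
#   default_config_exists("my-finetune-TURBO.safetensors") -> "stabilityai/sdxl-turbo"
#   default_config_exists("some-other-model")              -> None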
default_configs = {
"stabilityai/sdxl-turbo": [
gr.Textbox(label="", interactive=False, value=None, visible=False),
gr.Textbox(
label="Prompt",
value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette",
),
gr.Slider(0, 10, value=2),
gr.Dropdown(value="EulerAncestralDiscrete"),
gr.Slider(0, value=0),
512,
512,
"madebyollin/sdxl-vae-fp16-fix",
gr.Checkbox(
label="Auto-Generate", visible=False, interactive=True, value=False
),
],
"stabilityai/stable-diffusion-xl-base-1.0": [
gr.Textbox(label="Prompt", interactive=True, visible=True),
gr.Textbox(label="Negative Prompt", interactive=True),
40,
"EulerDiscrete",
7.5,
gr.Slider(value=768, interactive=True),
gr.Slider(value=768, interactive=True),
"madebyollin/sdxl-vae-fp16-fix",
gr.Checkbox(
label="Auto-Generate",
visible=False,
interactive=False,
value=False,
),
],
}
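# Each config list above is positional: its entries feed, in order, the
# prompt, negative prompt, steps, scheduler, guidance scale, width, height,
# custom VAE, and auto-generate outputs of set_model_default_configs().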
nodlogo_loc = resource_path("logos/nod-logo.png")
nodicon_loc = resource_path("logos/nod-icon.png")
available_devices = get_available_devices()

View File

@@ -26,7 +26,7 @@ diffusers
accelerate
scipy
ftfy
gradio==3.44.3
gradio==4.7.1
altair
omegaconf
# 0.3.2 doesn't have binaries for arm64

View File

@@ -83,7 +83,7 @@ def clean_device_info(raw_device):
device_id = int(device_id)
if device not in ["rocm", "vulkan"]:
device_id = ""
device_id = None
if device in ["rocm", "vulkan"] and device_id is None:
device_id = 0
return device, device_id
@@ -355,11 +355,15 @@ def get_iree_module(
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
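# Query the available-devices list once and reuse the id for both
# create_device() and the runtime config below.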
hal_device_id = haldriver.query_available_devices()[device_idx][
"device_id"
]
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
hal_device_id,
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
config.id = hal_device_id
else:
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_buffer(
@@ -398,15 +402,16 @@ def load_vmfb_using_mmap(
haldriver = ireert.get_driver(device)
dl.log(f"ireert.get_driver()")
hal_device_id = haldriver.query_available_devices()[device_idx][
"device_id"
]
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
hal_device_id,
allocators=shark_args.device_allocator,
)
dl.log(f"ireert.create_device()")
config = ireert.Config(device=haldevice)
config.id = haldriver.query_available_devices()[device_idx][
"device_id"
]
config.id = hal_device_id
dl.log(f"ireert.Config()")
else:
config = get_iree_runtime_config(device)

View File

@@ -183,6 +183,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
res_vulkan_flag = []
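# Cap any single stream resource allocation at 3221225472 bytes (3 GiB),
# presumably to stay within per-allocation limits on typical Vulkan drivers.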
res_vulkan_flag += [
"--iree-stream-resource-max-allocation-size=3221225472"
]
vulkan_triple_flag = None
for arg in extra_args:
if "-iree-vulkan-target-triple=" in arg:
@@ -204,7 +207,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
@functools.cache
def get_iree_vulkan_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
f"--vulkan_validation_layers={'true' if shark_args.vulkan_debug_utils else 'false'}",
f"--vulkan_debug_verbosity={'4' if shark_args.vulkan_debug_utils else '0'}"
f"--vulkan-robust-buffer-access={'true' if shark_args.vulkan_debug_utils else 'false'}",
]
return vulkan_runtime_flags

View File

@@ -0,0 +1,62 @@
import unittest
from unittest.mock import mock_open, patch
from apps.stable_diffusion.web.ui.txt2img_ui import (
export_settings,
load_settings,
all_gradio_labels,
)
class TestExportSettings(unittest.TestCase):
@patch("builtins.open", new_callable=mock_open)
@patch("json.dump")
def test_export_settings(self, mock_json_dump, mock_file):
test_values = ["value1", "value2", "value3"]
expected_output = {
"txt2img": {
label: value
for label, value in zip(all_gradio_labels, test_values)
}
}
export_settings(*test_values)
mock_file.assert_called_once_with("./ui/settings.json", "w")
mock_json_dump.assert_called_once_with(
expected_output, mock_file(), indent=4
)
@patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
@patch(
"builtins.open",
new_callable=mock_open,
read_data='{"txt2img": {"some_setting": "some_value"}}',
)
def test_load_settings_file_exists(self, mock_file, mock_json_load):
mock_json_load.return_value = {
"txt2img": {
"txt2img_custom_model": "custom_model_value",
"custom_vae": "custom_vae_value",
}
}
settings = load_settings()
self.assertEqual(settings[0], "custom_model_value")
self.assertEqual(settings[1], "custom_vae_value")
@patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
@patch("builtins.open", side_effect=FileNotFoundError)
def test_load_settings_file_not_found(self, mock_file, mock_json_load):
settings = load_settings()
default_lora_weights = "None"
self.assertEqual(settings[4], default_lora_weights)
@patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
@patch("builtins.open", new_callable=mock_open, read_data="{}")
def test_load_settings_key_error(self, mock_file, mock_json_load):
mock_json_load.return_value = {}
settings = load_settings()
default_lora_weights = "None"
self.assertEqual(settings[4], default_lora_weights)
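# Conventional entry point so the tests can also be run directly
# (python path/to/this_file.py), in addition to unittest discovery.
if __name__ == "__main__":
unittest.main()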