mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
15 Commits
20231202.1
...
20231207.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7159698496 | ||
|
|
7e12d1782a | ||
|
|
bb5f133e1c | ||
|
|
3af0c6c658 | ||
|
|
3322b7264f | ||
|
|
eeb7bdd143 | ||
|
|
2d6f48821d | ||
|
|
c74b55f24e | ||
|
|
1a723645fb | ||
|
|
dfdd3b1f78 | ||
|
|
6384780d16 | ||
|
|
db0c53ae59 | ||
|
|
ce9ce3a7c8 | ||
|
|
d72da3801f | ||
|
|
9c50edc664 |
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import time
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
@@ -66,7 +67,6 @@ class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers, lmhead, embedding, norm):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
# assert len(layers) == len(model.model.layers)
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers = layers
|
||||
@@ -110,9 +110,11 @@ class LMHeadCompiled(torch.nn.Module):
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach()
|
||||
hidden_states_sample = hidden_states.detach()
|
||||
|
||||
output = self.model("forward", (hidden_states,))
|
||||
output = torch.tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -136,8 +138,9 @@ class VicunaNormCompiled(torch.nn.Module):
|
||||
hidden_states.detach()
|
||||
except:
|
||||
pass
|
||||
output = self.model("forward", (hidden_states,))
|
||||
output = self.model("forward", (hidden_states,), send_to_host=True)
|
||||
output = torch.tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -158,15 +161,18 @@ class VicunaEmbeddingCompiled(torch.nn.Module):
|
||||
|
||||
def forward(self, input_ids):
|
||||
input_ids.detach()
|
||||
output = self.model("forward", (input_ids,))
|
||||
output = self.model("forward", (input_ids,), send_to_host=True)
|
||||
output = torch.tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class CompiledVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
def __init__(self, shark_module, idx, breakpoints):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
self.idx = idx
|
||||
self.breakpoints = breakpoints
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -177,10 +183,11 @@ class CompiledVicunaLayer(torch.nn.Module):
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
if self.breakpoints is None:
|
||||
is_breakpoint = False
|
||||
else:
|
||||
is_breakpoint = self.idx + 1 in self.breakpoints
|
||||
if past_key_value is None:
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
output = self.model(
|
||||
"first_vicuna_forward",
|
||||
(
|
||||
@@ -188,11 +195,17 @@ class CompiledVicunaLayer(torch.nn.Module):
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
send_to_host=is_breakpoint,
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
if is_breakpoint:
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
else:
|
||||
output0 = output[0]
|
||||
output1 = output[1]
|
||||
output2 = output[2]
|
||||
|
||||
return (
|
||||
output0,
|
||||
@@ -202,11 +215,8 @@ class CompiledVicunaLayer(torch.nn.Module):
|
||||
),
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv0 = past_key_value[0].detach()
|
||||
pkv1 = past_key_value[1].detach()
|
||||
pkv0 = past_key_value[0]
|
||||
pkv1 = past_key_value[1]
|
||||
output = self.model(
|
||||
"second_vicuna_forward",
|
||||
(
|
||||
@@ -216,11 +226,17 @@ class CompiledVicunaLayer(torch.nn.Module):
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
send_to_host=is_breakpoint,
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
if is_breakpoint:
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
else:
|
||||
output0 = output[0]
|
||||
output1 = output[1]
|
||||
output2 = output[2]
|
||||
|
||||
return (
|
||||
output0,
|
||||
|
||||
@@ -19,6 +19,9 @@ a = Analysis(
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
module_collection_mode={
|
||||
'gradio': 'py', # Collect gradio package as source .py files
|
||||
},
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ datas += copy_metadata("Pillow")
|
||||
datas += copy_metadata("sentencepiece")
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += copy_metadata("huggingface-hub")
|
||||
datas += copy_metadata("gradio")
|
||||
datas += collect_data_files("torch")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
@@ -75,6 +76,7 @@ datas += [
|
||||
# hidden imports for pyinstaller
|
||||
hiddenimports = ["shark", "shark.shark_inference", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
|
||||
hiddenimports += [
|
||||
x for x in collect_submodules("diffusers") if "tests" not in x
|
||||
]
|
||||
@@ -85,4 +87,4 @@ hiddenimports += [
|
||||
if not any(kw in x for kw in blacklist)
|
||||
]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]
|
||||
hiddenimports += ["iree._runtime"]
|
||||
|
||||
@@ -109,7 +109,7 @@ def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision):
|
||||
if "vulkan" in device:
|
||||
_device = args.iree_vulkan_target_triple
|
||||
_device = _device.replace("-", "_")
|
||||
vmfb_path = Path(extended_model_name_for_vmfb + f"_{_device}.vmfb")
|
||||
vmfb_path = Path(extended_model_name_for_vmfb + f"_vulkan.vmfb")
|
||||
if vmfb_path.exists():
|
||||
shark_module = SharkInference(
|
||||
None,
|
||||
@@ -190,9 +190,6 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
# TODO: remove the following line when stable-diffusion-2-1 works
|
||||
if self.model_id == "stabilityai/stable-diffusion-2-1":
|
||||
self.model_id = "stabilityai/stable-diffusion-2-1-base"
|
||||
self.custom_vae = custom_vae
|
||||
self.precision = precision
|
||||
self.base_vae = use_base_vae
|
||||
@@ -208,6 +205,7 @@ class SharkifyStableDiffusionModel:
|
||||
+ "_"
|
||||
+ precision
|
||||
)
|
||||
self.model_namedata = self.model_name
|
||||
print(f"use_tuned? sharkify: {use_tuned}")
|
||||
self.use_tuned = use_tuned
|
||||
if use_tuned:
|
||||
@@ -221,7 +219,6 @@ class SharkifyStableDiffusionModel:
|
||||
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
|
||||
self.use_lora = use_lora
|
||||
|
||||
print(self.model_name)
|
||||
self.model_name = self.get_extended_name_for_all_model()
|
||||
self.debug = debug
|
||||
self.sharktank_dir = sharktank_dir
|
||||
@@ -243,7 +240,7 @@ class SharkifyStableDiffusionModel:
|
||||
args.hf_model_id = self.base_model_id
|
||||
self.return_mlir = return_mlir
|
||||
|
||||
def get_extended_name_for_all_model(self):
|
||||
def get_extended_name_for_all_model(self, model_list=None):
|
||||
model_name = {}
|
||||
sub_model_list = [
|
||||
"clip",
|
||||
@@ -254,9 +251,11 @@ class SharkifyStableDiffusionModel:
|
||||
"stencil_unet_512",
|
||||
"vae",
|
||||
"vae_encode",
|
||||
"stencil_adaptor",
|
||||
"stencil_adaptor_512",
|
||||
"stencil_adapter",
|
||||
"stencil_adapter_512",
|
||||
]
|
||||
if model_list is not None:
|
||||
sub_model_list = model_list
|
||||
index = 0
|
||||
for model in sub_model_list:
|
||||
sub_model = model
|
||||
@@ -268,11 +267,24 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
if self.base_vae:
|
||||
sub_model = "base_vae"
|
||||
# TODO: Fix this
|
||||
# if "stencil_adaptor" == model and self.use_stencil is not None:
|
||||
# model_config = model_config + get_path_stem(self.use_stencil)
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
index += 1
|
||||
if "stencil_adapter" in model:
|
||||
stencil_names = []
|
||||
for i, stencil in enumerate(self.stencils):
|
||||
if stencil is not None:
|
||||
cnet_config = (
|
||||
self.model_namedata
|
||||
+ "_sd15_"
|
||||
+ stencil.split("_")[-1]
|
||||
)
|
||||
stencil_names.append(
|
||||
get_extended_name(sub_model + cnet_config)
|
||||
)
|
||||
|
||||
model_name[model] = stencil_names
|
||||
else:
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
index += 1
|
||||
|
||||
return model_name
|
||||
|
||||
def check_params(self, max_len, width, height):
|
||||
@@ -436,24 +448,48 @@ class SharkifyStableDiffusionModel:
|
||||
super().__init__()
|
||||
self.vae = None
|
||||
if custom_vae == "":
|
||||
print(f"Loading default vae, with target {model_id}")
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
elif not isinstance(custom_vae, dict):
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
precision = "fp16" if "fp16" in custom_vae else None
|
||||
print(f"Loading custom vae, with target {custom_vae}")
|
||||
if os.path.exists(custom_vae):
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
custom_vae = "/".join(
|
||||
[
|
||||
custom_vae.split("/")[-2].split("\\")[-1],
|
||||
custom_vae.split("/")[-1],
|
||||
]
|
||||
)
|
||||
print("Using hub to get custom vae")
|
||||
try:
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
variant=precision,
|
||||
)
|
||||
except:
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
print(f"Loading custom vae, with state {custom_vae}")
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.vae.load_state_dict(custom_vae)
|
||||
self.base_vae = base_vae
|
||||
|
||||
def forward(self, latents):
|
||||
image = self.vae.decode(latents / 0.13025, return_dict=False)[
|
||||
@@ -465,7 +501,12 @@ class SharkifyStableDiffusionModel:
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
# Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL
|
||||
# pipeline.
|
||||
is_f16 = False
|
||||
if not self.custom_vae:
|
||||
is_f16 = False
|
||||
elif "16" in self.custom_vae:
|
||||
is_f16 = True
|
||||
else:
|
||||
is_f16 = False
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
|
||||
if self.debug:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
@@ -502,7 +543,7 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.in_channels = self.unet.config.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(
|
||||
@@ -650,6 +691,8 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def get_control_net(self, stencil_id, use_large=False):
|
||||
stencil_id = get_stencil_model_id(stencil_id)
|
||||
adapter_id, base_model_safe_id, ext_model_name = (None, None, None)
|
||||
print(f"Importing ControlNet adapter from {stencil_id}")
|
||||
|
||||
class StencilControlNetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False):
|
||||
@@ -658,7 +701,7 @@ class SharkifyStableDiffusionModel:
|
||||
model_id,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.in_channels = self.cnet.in_channels
|
||||
self.in_channels = self.cnet.config.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(
|
||||
@@ -722,7 +765,25 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
|
||||
inputs = tuple(self.inputs["stencil_adaptor"])
|
||||
inputs = tuple(self.inputs["stencil_adapter"])
|
||||
model_name = "stencil_adapter_512" if use_large else "stencil_adapter"
|
||||
stencil_names = self.get_extended_name_for_all_model([model_name])
|
||||
ext_model_name = stencil_names[model_name]
|
||||
if isinstance(ext_model_name, list):
|
||||
desired_name = None
|
||||
print(ext_model_name)
|
||||
for i in ext_model_name:
|
||||
if stencil_id.split("_")[-1] in i:
|
||||
desired_name = i
|
||||
else:
|
||||
continue
|
||||
if desired_name:
|
||||
ext_model_name = desired_name
|
||||
else:
|
||||
raise Exception(
|
||||
f"Could not find extended configuration for {stencil_id}"
|
||||
)
|
||||
|
||||
if use_large:
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
@@ -732,19 +793,13 @@ class SharkifyStableDiffusionModel:
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
*inputs[3:],
|
||||
)
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
|
||||
)
|
||||
else:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor"]
|
||||
)
|
||||
save_dir = os.path.join(self.sharktank_dir, ext_model_name)
|
||||
input_mask = [True, True, True, True] + ([True] * 13)
|
||||
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
|
||||
|
||||
shark_cnet, cnet_mlir = compile_through_fx(
|
||||
scnet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name[model_name],
|
||||
extended_model_name=ext_model_name,
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
@@ -917,11 +972,19 @@ class SharkifyStableDiffusionModel:
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
try:
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
variant="fp16",
|
||||
)
|
||||
except:
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if (
|
||||
args.attention_slicing is not None
|
||||
and args.attention_slicing != "none"
|
||||
@@ -1278,16 +1341,16 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def controlnet(self, stencil_id, use_large=False):
|
||||
try:
|
||||
self.inputs["stencil_adaptor"] = self.get_input_info_for(
|
||||
base_models["stencil_adaptor"]
|
||||
self.inputs["stencil_adapter"] = self.get_input_info_for(
|
||||
base_models["stencil_adapter"]
|
||||
)
|
||||
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
|
||||
compiled_stencil_adapter, controlnet_mlir = self.get_control_net(
|
||||
stencil_id, use_large=use_large
|
||||
)
|
||||
|
||||
check_compilation(compiled_stencil_adaptor, "Stencil")
|
||||
check_compilation(compiled_stencil_adapter, "Stencil")
|
||||
if self.return_mlir:
|
||||
return controlnet_mlir
|
||||
return compiled_stencil_adaptor
|
||||
return compiled_stencil_adapter
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
@@ -123,7 +123,10 @@ def get_clip():
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_tokenizer(subfolder="tokenizer"):
|
||||
def get_tokenizer(subfolder="tokenizer", hf_model_id=None):
|
||||
if hf_model_id is not None:
|
||||
args.hf_model_id = hf_model_id
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.hf_model_id, subfolder=subfolder
|
||||
)
|
||||
|
||||
@@ -158,7 +158,11 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
stencils,
|
||||
images,
|
||||
resample_type,
|
||||
control_mode,
|
||||
preprocessed_hints=[],
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
|
||||
@@ -25,12 +25,22 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import controlnet_hint_conversion
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
controlnet_hint_conversion,
|
||||
controlnet_hint_reshaping,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
start_profiling,
|
||||
end_profiling,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
resamplers,
|
||||
resampler_list,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
)
|
||||
|
||||
|
||||
class StencilPipeline(StableDiffusionPipeline):
|
||||
@@ -63,6 +73,24 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
self.controlnet_id = [str] * len(controlnet_names)
|
||||
self.controlnet_512_id = [str] * len(controlnet_names)
|
||||
self.controlnet_names = controlnet_names
|
||||
self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
|
||||
if self.vae_encode is not None:
|
||||
return
|
||||
|
||||
if self.import_mlir or self.use_lora:
|
||||
self.vae_encode = self.sd_model.vae_encode()
|
||||
else:
|
||||
try:
|
||||
self.vae_encode = get_vae_encode()
|
||||
except:
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.vae_encode = self.sd_model.vae_encode()
|
||||
|
||||
def unload_vae_encode(self):
|
||||
del self.vae_encode
|
||||
self.vae_encode = None
|
||||
|
||||
def load_controlnet(self, index, model_name):
|
||||
if model_name is None:
|
||||
@@ -122,6 +150,58 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
strength,
|
||||
dtype,
|
||||
resample_type,
|
||||
):
|
||||
# Pre process image -> get image encoded -> process latents
|
||||
|
||||
# TODO: process with variable HxW combos
|
||||
|
||||
# Pre-process image
|
||||
resample_type = (
|
||||
resamplers[resample_type]
|
||||
if resample_type in resampler_list
|
||||
# Fallback to Lanczos
|
||||
else Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
image = image.resize((width, height), resample=resample_type)
|
||||
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
|
||||
image_arr = image_arr / 255.0
|
||||
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
|
||||
image_arr = 2 * (image_arr - 0.5)
|
||||
|
||||
# set scheduler steps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
init_timestep = min(
|
||||
int(num_inference_steps * strength), num_inference_steps
|
||||
)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
# timesteps reduced as per strength
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
# new number of steps to be used as per strength will be
|
||||
# num_inference_steps = num_inference_steps - t_start
|
||||
|
||||
# image encode
|
||||
latents = self.encode_image((image_arr,))
|
||||
latents = torch.from_numpy(latents).to(dtype)
|
||||
# add noise to data
|
||||
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
|
||||
latents = self.scheduler.add_noise(
|
||||
latents, noise, timesteps[0].repeat(1)
|
||||
)
|
||||
|
||||
return latents, timesteps
|
||||
|
||||
def produce_stencil_latents(
|
||||
self,
|
||||
latents,
|
||||
@@ -148,10 +228,16 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
self.load_unet_512()
|
||||
|
||||
for i, name in enumerate(self.controlnet_names):
|
||||
use_names = []
|
||||
if name is not None:
|
||||
use_names.append(name)
|
||||
else:
|
||||
continue
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_controlnet(i, name)
|
||||
else:
|
||||
self.load_controlnet_512(i, name)
|
||||
self.controlnet_names = use_names
|
||||
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
@@ -213,7 +299,7 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
for i, controlnet_hint in enumerate(stencil_hints):
|
||||
if controlnet_hint is None:
|
||||
continue
|
||||
pass
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
control = self.controlnet[i](
|
||||
"forward",
|
||||
@@ -300,7 +386,6 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
print(self.unet_512)
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
@@ -368,6 +453,17 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
all_latents = torch.cat(latent_history, dim=0)
|
||||
return all_latents
|
||||
|
||||
def encode_image(self, input_image):
|
||||
self.load_vae_encode()
|
||||
vae_encode_start = time.time()
|
||||
latents = self.vae_encode("forward", input_image)
|
||||
vae_inf_time = (time.time() - vae_encode_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_vae_encode()
|
||||
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
|
||||
|
||||
return latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
@@ -389,14 +485,32 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
stencil_images,
|
||||
resample_type,
|
||||
control_mode,
|
||||
preprocessed_hints,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
# TODO: 1. Change `num_images_per_prompt`.
|
||||
# controlnet_hint = controlnet_hint_conversion(
|
||||
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
|
||||
# )
|
||||
stencil_hints = []
|
||||
self.sd_model.stencils = stencils
|
||||
for i, hint in enumerate(preprocessed_hints):
|
||||
if hint is not None:
|
||||
hint = controlnet_hint_reshaping(
|
||||
hint,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt=1,
|
||||
)
|
||||
stencil_hints.append(hint)
|
||||
|
||||
for i, stencil in enumerate(stencils):
|
||||
if stencil == None:
|
||||
continue
|
||||
if len(stencil_hints) > i:
|
||||
if stencil_hints[i] is not None:
|
||||
print(f"Using preprocessed controlnet hint for {stencil}")
|
||||
continue
|
||||
image = stencil_images[i]
|
||||
stencil_hints.append(
|
||||
controlnet_hint_conversion(
|
||||
@@ -436,17 +550,30 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Prepare initial latent.
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
final_timesteps = self.scheduler.timesteps
|
||||
if image is not None:
|
||||
# Prepare input image latent
|
||||
init_latents, final_timesteps = self.prepare_image_latents(
|
||||
image=image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
strength=strength,
|
||||
dtype=dtype,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
else:
|
||||
# Prepare initial latent.
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
final_timesteps = self.scheduler.timesteps
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_stencil_latents(
|
||||
|
||||
@@ -18,7 +18,10 @@ from diffusers import (
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
|
||||
@@ -16,7 +16,10 @@ from diffusers import (
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
@@ -38,6 +41,7 @@ class Text2ImageSDXLPipeline(StableDiffusionPipeline):
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
@@ -48,8 +52,10 @@ class Text2ImageSDXLPipeline(StableDiffusionPipeline):
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
is_fp32_vae: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.is_fp32_vae = is_fp32_vae
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
@@ -203,10 +209,10 @@ class Text2ImageSDXLPipeline(StableDiffusionPipeline):
|
||||
# Img latents -> PIL images.
|
||||
all_imgs = []
|
||||
self.load_vae()
|
||||
# imgs = self.decode_latents_sdxl(None)
|
||||
# all_imgs.extend(imgs)
|
||||
for i in range(0, latents.shape[0], batch_size):
|
||||
imgs = self.decode_latents_sdxl(latents[i : i + batch_size])
|
||||
imgs = self.decode_latents_sdxl(
|
||||
latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
if self.ondemand:
|
||||
self.unload_vae()
|
||||
|
||||
@@ -20,7 +20,10 @@ from diffusers import (
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae,
|
||||
@@ -52,6 +55,7 @@ class StableDiffusionPipeline:
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
@@ -62,6 +66,7 @@ class StableDiffusionPipeline:
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
is_f32_vae: bool = False,
|
||||
):
|
||||
self.vae = None
|
||||
self.text_encoder = None
|
||||
@@ -69,14 +74,15 @@ class StableDiffusionPipeline:
|
||||
self.unet = None
|
||||
self.unet_512 = None
|
||||
self.model_max_length = 77
|
||||
self.scheduler = scheduler
|
||||
# TODO: Implement using logging python utility.
|
||||
self.log = ""
|
||||
self.status = SD_STATE_IDLE
|
||||
self.sd_model = sd_model
|
||||
self.scheduler = scheduler
|
||||
self.import_mlir = import_mlir
|
||||
self.use_lora = use_lora
|
||||
self.ondemand = ondemand
|
||||
self.is_f32_vae = is_f32_vae
|
||||
# TODO: Find a better workaround for fetching base_model_id early
|
||||
# enough for CLIPTokenizer.
|
||||
try:
|
||||
@@ -202,6 +208,9 @@ class StableDiffusionPipeline:
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
hf_model_id: Optional[
|
||||
str
|
||||
] = "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -211,7 +220,7 @@ class StableDiffusionPipeline:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
self.tokenizer_2 = get_tokenizer("tokenizer_2")
|
||||
self.tokenizer_2 = get_tokenizer("tokenizer_2", hf_model_id)
|
||||
self.load_clip_sdxl()
|
||||
tokenizers = (
|
||||
[self.tokenizer, self.tokenizer_2]
|
||||
@@ -332,7 +341,7 @@ class StableDiffusionPipeline:
|
||||
gc.collect()
|
||||
|
||||
# TODO: Look into dtype for text_encoder_2!
|
||||
prompt_embeds = prompt_embeds.to(dtype=torch.float32)
|
||||
prompt_embeds = prompt_embeds.to(dtype=torch.float16)
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -523,6 +532,9 @@ class StableDiffusionPipeline:
|
||||
cpu_scheduling,
|
||||
guidance_scale,
|
||||
dtype,
|
||||
mask=None,
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
):
|
||||
# return None
|
||||
self.status = SD_STATE_IDLE
|
||||
@@ -533,11 +545,22 @@ class StableDiffusionPipeline:
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
if isinstance(latents, np.ndarray):
|
||||
latents = torch.tensor(latents)
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
).to(dtype)
|
||||
)
|
||||
if mask is not None and masked_image_latents is not None:
|
||||
latent_model_input = torch.cat(
|
||||
[
|
||||
torch.from_numpy(np.asarray(latent_model_input)),
|
||||
mask,
|
||||
masked_image_latents,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype)
|
||||
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
@@ -549,11 +572,17 @@ class StableDiffusionPipeline:
|
||||
add_time_ids,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
send_to_host=True,
|
||||
)
|
||||
if not isinstance(latents, torch.Tensor):
|
||||
latents = torch.from_numpy(latents).to("cpu")
|
||||
noise_pred = torch.from_numpy(noise_pred).to("cpu")
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
||||
)[0]
|
||||
latents = latents.detach().numpy()
|
||||
noise_pred = noise_pred.detach().numpy()
|
||||
|
||||
step_time = (time.time() - step_start_time) * 1000
|
||||
step_time_sum += step_time
|
||||
@@ -569,11 +598,15 @@ class StableDiffusionPipeline:
|
||||
|
||||
return latents
|
||||
|
||||
def decode_latents_sdxl(self, latents):
|
||||
latents = latents.to(torch.float32)
|
||||
def decode_latents_sdxl(self, latents, is_fp32_vae):
|
||||
# latents are in unet dtype here so switch if we want to use fp32
|
||||
if is_fp32_vae:
|
||||
print("Casting latents to float32 for VAE")
|
||||
latents = latents.to(torch.float32)
|
||||
images = self.vae("forward", (latents,))
|
||||
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
|
||||
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
images = (images * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
|
||||
|
||||
@@ -666,6 +699,17 @@ class StableDiffusionPipeline:
|
||||
return cls(
|
||||
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
|
||||
)
|
||||
if cls.__name__ == "Text2ImageSDXLPipeline":
|
||||
is_fp32_vae = True if "16" not in custom_vae else False
|
||||
return cls(
|
||||
scheduler,
|
||||
sd_model,
|
||||
import_mlir,
|
||||
use_lora,
|
||||
ondemand,
|
||||
is_fp32_vae,
|
||||
)
|
||||
|
||||
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
|
||||
# #####################################################
|
||||
@@ -781,7 +825,7 @@ class StableDiffusionPipeline:
|
||||
gc.collect()
|
||||
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
|
||||
|
||||
return text_embeddings.numpy()
|
||||
return text_embeddings.numpy().astype(np.float16)
|
||||
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from diffusers import (
|
||||
LCMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDPMScheduler,
|
||||
@@ -15,9 +16,21 @@ from diffusers import (
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
|
||||
|
||||
def get_schedulers(model_id):
|
||||
# TODO: Robust scheduler setup on pipeline creation -- if we don't
|
||||
# set batch_size here, the SHARK schedulers will
|
||||
# compile with batch size = 1 regardless of whether the model
|
||||
# outputs latents of a larger batch size, e.g. SDXL.
|
||||
# However, obviously, searching for whether the base model ID
|
||||
# contains "xl" is not very robust.
|
||||
|
||||
batch_size = 2 if "xl" in model_id.lower() else 1
|
||||
|
||||
schedulers = dict()
|
||||
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
|
||||
model_id,
|
||||
@@ -39,6 +52,10 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistep"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
@@ -84,6 +101,12 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"SharkEulerAncestralDiscrete"
|
||||
] = SharkEulerAncestralDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverSinglestep"
|
||||
] = DPMSolverSinglestepScheduler.from_pretrained(
|
||||
@@ -100,5 +123,6 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["SharkEulerDiscrete"].compile()
|
||||
schedulers["SharkEulerDiscrete"].compile(batch_size)
|
||||
schedulers["SharkEulerAncestralDiscrete"].compile(batch_size)
|
||||
return schedulers
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_shark_model,
|
||||
args,
|
||||
)
|
||||
import torch
|
||||
|
||||
|
||||
class SharkEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
num_train_timesteps,
|
||||
beta_start,
|
||||
beta_end,
|
||||
beta_schedule,
|
||||
trained_betas,
|
||||
prediction_type,
|
||||
timestep_spacing,
|
||||
steps_offset,
|
||||
)
|
||||
# TODO: make it dynamic so we dont have to worry about batch size
|
||||
self.batch_size = None
|
||||
self.init_input_shape = None
|
||||
|
||||
def compile(self, batch_size=1):
|
||||
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
device = args.device.split(":", 1)[0].strip()
|
||||
self.batch_size = batch_size
|
||||
|
||||
model_input = {
|
||||
"eulera": {
|
||||
"output": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"latent": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"sigma": torch.tensor(1).to(torch.float32),
|
||||
"sigma_from": torch.tensor(1).to(torch.float32),
|
||||
"sigma_to": torch.tensor(1).to(torch.float32),
|
||||
"noise": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
example_latent = model_input["eulera"]["latent"]
|
||||
example_output = model_input["eulera"]["output"]
|
||||
example_noise = model_input["eulera"]["noise"]
|
||||
if args.precision == "fp16":
|
||||
example_latent = example_latent.half()
|
||||
example_output = example_output.half()
|
||||
example_noise = example_noise.half()
|
||||
example_sigma = model_input["eulera"]["sigma"]
|
||||
example_sigma_from = model_input["eulera"]["sigma_from"]
|
||||
example_sigma_to = model_input["eulera"]["sigma_to"]
|
||||
|
||||
class ScalingModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, latent, sigma):
|
||||
return latent / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
class SchedulerStepEpsilonModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self, noise_pred, latent, sigma, sigma_from, sigma_to, noise
|
||||
):
|
||||
sigma_up = (
|
||||
sigma_to**2
|
||||
* (sigma_from**2 - sigma_to**2)
|
||||
/ sigma_from**2
|
||||
) ** 0.5
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||||
dt = sigma_down - sigma
|
||||
pred_original_sample = latent - sigma * noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
prev_sample = latent + derivative * dt
|
||||
return prev_sample + noise * sigma_up
|
||||
|
||||
class SchedulerStepVPredictionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self, noise_pred, sigma, sigma_from, sigma_to, latent, noise
|
||||
):
|
||||
sigma_up = (
|
||||
sigma_to**2
|
||||
* (sigma_from**2 - sigma_to**2)
|
||||
/ sigma_from**2
|
||||
) ** 0.5
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||||
dt = sigma_down - sigma
|
||||
pred_original_sample = noise_pred * (
|
||||
-sigma / (sigma**2 + 1) ** 0.5
|
||||
) + (latent / (sigma**2 + 1))
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
prev_sample = latent + derivative * dt
|
||||
return prev_sample + noise * sigma_up
|
||||
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
self.scaling_model, _ = compile_through_fx(
|
||||
model=scaling_model,
|
||||
inputs=(example_latent, example_sigma),
|
||||
extended_model_name=f"euler_a_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
pred_type_model_dict = {
|
||||
"epsilon": SchedulerStepEpsilonModel(),
|
||||
"v_prediction": SchedulerStepVPredictionModel(),
|
||||
}
|
||||
step_model = pred_type_model_dict[self.config.prediction_type]
|
||||
self.step_model, _ = compile_through_fx(
|
||||
step_model,
|
||||
(
|
||||
example_output,
|
||||
example_latent,
|
||||
example_sigma,
|
||||
example_sigma_from,
|
||||
example_sigma_to,
|
||||
example_noise,
|
||||
),
|
||||
extended_model_name=f"euler_a_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
if args.import_mlir:
|
||||
_import(self)
|
||||
|
||||
else:
|
||||
try:
|
||||
self.scaling_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_a_scale_model_input_" + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
self.step_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_a_step_"
|
||||
+ self.config.prediction_type
|
||||
+ args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
except:
|
||||
print(
|
||||
"failed to download model, falling back and using import_mlir"
|
||||
)
|
||||
args.import_mlir = True
|
||||
_import(self)
|
||||
|
||||
def scale_model_input(self, sample, timestep):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
sigma = self.sigmas[self.step_index]
|
||||
return self.scaling_model(
|
||||
"forward",
|
||||
(
|
||||
sample,
|
||||
sigma,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
def step(
|
||||
self,
|
||||
noise_pred,
|
||||
timestep,
|
||||
latent,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: Optional[bool] = False,
|
||||
):
|
||||
step_inputs = []
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
sigma_from = self.sigmas[self.step_index]
|
||||
sigma_to = self.sigmas[self.step_index + 1]
|
||||
noise = randn_tensor(
|
||||
torch.Size(noise_pred.shape),
|
||||
dtype=torch.float16,
|
||||
device="cpu",
|
||||
generator=generator,
|
||||
)
|
||||
step_inputs = [
|
||||
noise_pred,
|
||||
latent,
|
||||
sigma,
|
||||
sigma_from,
|
||||
sigma_to,
|
||||
noise,
|
||||
]
|
||||
# TODO: deal with dynamic inputs in turbine flow.
|
||||
# update step index since we're done with the variable and will return with compiled module output.
|
||||
self._step_index += 1
|
||||
|
||||
if noise_pred.shape[0] < self.batch_size:
|
||||
for i in [0, 1, 5]:
|
||||
try:
|
||||
step_inputs[i] = torch.tensor(step_inputs[i])
|
||||
except:
|
||||
step_inputs[i] = torch.tensor(step_inputs[i].to_host())
|
||||
step_inputs[i] = torch.cat(
|
||||
(step_inputs[i], step_inputs[i]), axis=0
|
||||
)
|
||||
return self.step_model(
|
||||
"forward",
|
||||
tuple(step_inputs),
|
||||
send_to_host=True,
|
||||
)
|
||||
|
||||
return self.step_model(
|
||||
"forward",
|
||||
tuple(step_inputs),
|
||||
send_to_host=False,
|
||||
)
|
||||
@@ -2,12 +2,9 @@ import sys
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
@@ -27,6 +24,13 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
interpolation_type: str = "linear",
|
||||
use_karras_sigmas: bool = False,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
timestep_type: str = "discrete",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
num_train_timesteps,
|
||||
@@ -35,20 +39,29 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
beta_schedule,
|
||||
trained_betas,
|
||||
prediction_type,
|
||||
interpolation_type,
|
||||
use_karras_sigmas,
|
||||
sigma_min,
|
||||
sigma_max,
|
||||
timestep_spacing,
|
||||
timestep_type,
|
||||
steps_offset,
|
||||
)
|
||||
# TODO: make it dynamic so we dont have to worry about batch size
|
||||
self.batch_size = 1
|
||||
|
||||
def compile(self):
|
||||
def compile(self, batch_size=1):
|
||||
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
BATCH_SIZE = args.batch_size
|
||||
device = args.device.split(":", 1)[0].strip()
|
||||
self.batch_size = batch_size
|
||||
|
||||
model_input = {
|
||||
"euler": {
|
||||
"latent": torch.randn(
|
||||
BATCH_SIZE, 4, args.height // 8, args.width // 8
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"output": torch.randn(
|
||||
BATCH_SIZE, 4, args.height // 8, args.width // 8
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"sigma": torch.tensor(1).to(torch.float32),
|
||||
"dt": torch.tensor(1).to(torch.float32),
|
||||
@@ -70,12 +83,32 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
def forward(self, latent, sigma):
|
||||
return latent / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
class SchedulerStepModel(torch.nn.Module):
|
||||
class SchedulerStepEpsilonModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma_hat, latent, dt):
|
||||
pred_original_sample = latent - sigma_hat * noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma_hat
|
||||
return latent + derivative * dt
|
||||
|
||||
class SchedulerStepSampleModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma_hat, latent, dt):
|
||||
pred_original_sample = noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma_hat
|
||||
return latent + derivative * dt
|
||||
|
||||
class SchedulerStepVPredictionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma, latent, dt):
|
||||
pred_original_sample = latent - sigma * noise_pred
|
||||
pred_original_sample = noise_pred * (
|
||||
-sigma / (sigma**2 + 1) ** 0.5
|
||||
) + (latent / (sigma**2 + 1))
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
return latent + derivative * dt
|
||||
|
||||
@@ -90,16 +123,22 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
self.scaling_model, _ = compile_through_fx(
|
||||
model=scaling_model,
|
||||
inputs=(example_latent, example_sigma),
|
||||
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
|
||||
extended_model_name=f"euler_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
step_model = SchedulerStepModel()
|
||||
pred_type_model_dict = {
|
||||
"epsilon": SchedulerStepEpsilonModel(),
|
||||
"v_prediction": SchedulerStepVPredictionModel(),
|
||||
"sample": SchedulerStepSampleModel(),
|
||||
"original_sample": SchedulerStepSampleModel(),
|
||||
}
|
||||
step_model = pred_type_model_dict[self.config.prediction_type]
|
||||
self.step_model, _ = compile_through_fx(
|
||||
step_model,
|
||||
(example_output, example_sigma, example_latent, example_dt),
|
||||
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
|
||||
extended_model_name=f"euler_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
@@ -109,6 +148,11 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
|
||||
else:
|
||||
try:
|
||||
step_model_type = (
|
||||
"sample"
|
||||
if "sample" in self.config.prediction_type
|
||||
else self.config.prediction_type
|
||||
)
|
||||
self.scaling_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_scale_model_input_" + args.precision,
|
||||
@@ -116,7 +160,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
)
|
||||
self.step_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_step_" + args.precision,
|
||||
"euler_step_" + step_model_type + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
except:
|
||||
@@ -127,8 +171,9 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
_import(self)
|
||||
|
||||
def scale_model_input(self, sample, timestep):
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
sigma = self.sigmas[self.step_index]
|
||||
return self.scaling_model(
|
||||
"forward",
|
||||
(
|
||||
@@ -138,15 +183,61 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
def step(self, noise_pred, timestep, latent):
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
dt = self.sigmas[step_index + 1] - sigma
|
||||
def step(
|
||||
self,
|
||||
noise_pred,
|
||||
timestep,
|
||||
latent,
|
||||
s_churn: float = 0.0,
|
||||
s_tmin: float = 0.0,
|
||||
s_tmax: float = float("inf"),
|
||||
s_noise: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: Optional[bool] = False,
|
||||
):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
gamma = (
|
||||
min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
|
||||
if s_tmin <= sigma <= s_tmax
|
||||
else 0.0
|
||||
)
|
||||
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
|
||||
noise_pred = (
|
||||
torch.from_numpy(noise_pred)
|
||||
if isinstance(noise_pred, np.ndarray)
|
||||
else noise_pred
|
||||
)
|
||||
|
||||
noise = randn_tensor(
|
||||
torch.Size(noise_pred.shape),
|
||||
dtype=torch.float16,
|
||||
device="cpu",
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
eps = noise * s_noise
|
||||
|
||||
if gamma > 0:
|
||||
latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
if self.config.prediction_type == "v_prediction":
|
||||
sigma_hat = sigma
|
||||
|
||||
dt = self.sigmas[self.step_index + 1] - sigma_hat
|
||||
|
||||
self._step_index += 1
|
||||
|
||||
return self.step_model(
|
||||
"forward",
|
||||
(
|
||||
noise_pred,
|
||||
sigma,
|
||||
sigma_hat,
|
||||
latent,
|
||||
dt,
|
||||
),
|
||||
|
||||
@@ -13,6 +13,7 @@ from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
|
||||
controlnet_hint_conversion,
|
||||
controlnet_hint_reshaping,
|
||||
get_stencil_model_id,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
|
||||
@@ -189,6 +189,49 @@
|
||||
"dtype": "i64"
|
||||
}
|
||||
},
|
||||
"stabilityai/sdxl-turbo": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"prompt_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
2048
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"text_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
1280
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"time_ids": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
6
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
@@ -233,7 +276,7 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"stencil_adaptor": {
|
||||
"stencil_adapter": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
@@ -449,4 +492,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
[["A high tech solarpunk utopia in the Amazon rainforest"],
|
||||
["Astrophotography, the shark nebula, nebula with a tiny shark-like cloud in the middle in the middle, hubble telescope, vivid colors"],
|
||||
["A pikachu fine dining with a view to the Eiffel Tower"],
|
||||
["A mecha robot in a favela in expressionist style"],
|
||||
["an insect robot preparing a delicious meal"],
|
||||
|
||||
@@ -467,6 +467,13 @@ p.add_argument(
|
||||
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
|
||||
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--autogen",
|
||||
type=bool,
|
||||
default="False",
|
||||
help="Only used for a gradio workaround.",
|
||||
)
|
||||
##############################################################################
|
||||
# IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
@@ -20,9 +20,7 @@ def save_img(img):
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
|
||||
subdir = Path(
|
||||
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
|
||||
)
|
||||
subdir = Path(get_generated_imgs_path(), "preprocessed_control_hints")
|
||||
os.makedirs(subdir, exist_ok=True)
|
||||
if isinstance(img, Image.Image):
|
||||
img.save(
|
||||
@@ -60,7 +58,7 @@ def HWC3(x):
|
||||
return y
|
||||
|
||||
|
||||
def controlnet_hint_shaping(
|
||||
def controlnet_hint_reshaping(
|
||||
controlnet_hint, height, width, dtype, num_images_per_prompt=1
|
||||
):
|
||||
channels = 3
|
||||
@@ -79,10 +77,12 @@ def controlnet_hint_shaping(
|
||||
)
|
||||
return controlnet_hint
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width}),"
|
||||
+ f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
|
||||
+ f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
|
||||
return controlnet_hint_reshaping(
|
||||
Image.fromarray(controlnet_hint.detach().numpy()),
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt,
|
||||
)
|
||||
elif isinstance(controlnet_hint, np.ndarray):
|
||||
# np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_promot)
|
||||
@@ -109,29 +109,38 @@ def controlnet_hint_shaping(
|
||||
) # b h w c -> b c h w
|
||||
return controlnet_hint
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptble shape of `stencil` are any of ({width}, {channels}), "
|
||||
+ f"({height}, {width}, {channels}), "
|
||||
+ f"(1, {height}, {width}, {channels}) or "
|
||||
+ f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
|
||||
return controlnet_hint_reshaping(
|
||||
Image.fromarray(controlnet_hint),
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt,
|
||||
)
|
||||
|
||||
elif isinstance(controlnet_hint, Image.Image):
|
||||
controlnet_hint = controlnet_hint.convert(
|
||||
"RGB"
|
||||
) # make sure 3 channel RGB format
|
||||
if controlnet_hint.size == (width, height):
|
||||
controlnet_hint = controlnet_hint.convert(
|
||||
"RGB"
|
||||
) # make sure 3 channel RGB format
|
||||
controlnet_hint = np.array(controlnet_hint) # to numpy
|
||||
controlnet_hint = np.array(controlnet_hint).astype(
|
||||
np.float16
|
||||
) # to numpy
|
||||
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
|
||||
return controlnet_hint_shaping(
|
||||
controlnet_hint, height, width, num_images_per_prompt
|
||||
return controlnet_hint_reshaping(
|
||||
controlnet_hint, height, width, dtype, num_images_per_prompt
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}"
|
||||
)
|
||||
(hint_w, hint_h) = controlnet_hint.size
|
||||
left = int((hint_w - width) / 2)
|
||||
right = left + height
|
||||
controlnet_hint = controlnet_hint.crop((left, 0, right, hint_h))
|
||||
controlnet_hint = controlnet_hint.resize((width, height))
|
||||
return controlnet_hint_reshaping(
|
||||
controlnet_hint, height, width, dtype, num_images_per_prompt
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptable type of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
|
||||
f"Acceptible controlnet input types are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -141,20 +150,26 @@ def controlnet_hint_conversion(
|
||||
controlnet_hint = None
|
||||
match use_stencil:
|
||||
case "canny":
|
||||
print("Detecting edge with canny")
|
||||
print(
|
||||
"Converting controlnet hint to edge detection mask with canny preprocessor."
|
||||
)
|
||||
controlnet_hint = hint_canny(image)
|
||||
case "openpose":
|
||||
print("Detecting human pose")
|
||||
print(
|
||||
"Detecting human pose in controlnet hint with openpose preprocessor."
|
||||
)
|
||||
controlnet_hint = hint_openpose(image)
|
||||
case "scribble":
|
||||
print("Working with scribble")
|
||||
print("Using your scribble as a controlnet hint.")
|
||||
controlnet_hint = hint_scribble(image)
|
||||
case "zoedepth":
|
||||
print("Working with ZoeDepth")
|
||||
print(
|
||||
"Converting controlnet hint to a depth mapping with ZoeDepth."
|
||||
)
|
||||
controlnet_hint = hint_zoedepth(image)
|
||||
case _:
|
||||
return None
|
||||
controlnet_hint = controlnet_hint_shaping(
|
||||
controlnet_hint = controlnet_hint_reshaping(
|
||||
controlnet_hint, height, width, dtype, num_images_per_prompt
|
||||
)
|
||||
return controlnet_hint
|
||||
|
||||
@@ -30,9 +30,15 @@ class ZoeDetector:
|
||||
pretrained=False,
|
||||
force_reload=False,
|
||||
)
|
||||
model.load_state_dict(
|
||||
torch.load(modelpath, map_location=model.device)["model"]
|
||||
)
|
||||
|
||||
# Hack to fix the ZoeDepth import issue
|
||||
model_keys = model.state_dict().keys()
|
||||
loaded_dict = torch.load(modelpath, map_location=model.device)["model"]
|
||||
loaded_keys = loaded_dict.keys()
|
||||
for key in loaded_keys - model_keys:
|
||||
loaded_dict.pop(key)
|
||||
|
||||
model.load_state_dict(loaded_dict)
|
||||
model.eval()
|
||||
self.model = model
|
||||
|
||||
|
||||
@@ -1008,8 +1008,7 @@ def get_generation_text_info(seeds, device):
|
||||
# Both width and height should be in the range of [128, 768] and multiple of 8.
|
||||
# This utility function performs the transformation on the input image while
|
||||
# also maintaining the aspect ratio before sending it to the stencil pipeline.
|
||||
def resize_stencil(image: Image.Image):
|
||||
width, height = image.size
|
||||
def resize_stencil(image: Image.Image, width, height):
|
||||
aspect_ratio = width / height
|
||||
min_size = min(width, height)
|
||||
if min_size < 128:
|
||||
|
||||
@@ -19,6 +19,9 @@ a = Analysis(
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
module_collection_mode={
|
||||
'gradio': 'py', # Collect gradio package as source .py files
|
||||
},
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
|
||||
@@ -97,8 +97,6 @@ if __name__ == "__main__":
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
@@ -110,11 +108,15 @@ if __name__ == "__main__":
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
# SDXL
|
||||
txt2img_sdxl_inf,
|
||||
txt2img_sdxl_web,
|
||||
txt2img_sdxl_custom_model,
|
||||
txt2img_sdxl_gallery,
|
||||
txt2img_sdxl_png_info_img,
|
||||
txt2img_sdxl_status,
|
||||
txt2img_sdxl_sendto_img2img,
|
||||
txt2img_sdxl_sendto_inpaint,
|
||||
txt2img_sdxl_sendto_outpaint,
|
||||
txt2img_sdxl_sendto_upscaler,
|
||||
# h2ogpt_upload,
|
||||
# h2ogpt_web,
|
||||
img2img_web,
|
||||
@@ -151,7 +153,7 @@ if __name__ == "__main__":
|
||||
upscaler_sendto_outpaint,
|
||||
# lora_train_web,
|
||||
# model_web,
|
||||
# model_config_web,
|
||||
model_config_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
@@ -165,6 +167,7 @@ if __name__ == "__main__":
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_txt2img_sdxl,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
@@ -178,7 +181,7 @@ if __name__ == "__main__":
|
||||
button.click(
|
||||
lambda x: (
|
||||
x[0]["name"] if len(x) != 0 else None,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
gr.Tabs(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
@@ -189,7 +192,7 @@ if __name__ == "__main__":
|
||||
lambda x: (
|
||||
"None",
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
gr.Tabs(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
@@ -199,12 +202,14 @@ if __name__ == "__main__":
|
||||
button.click(
|
||||
lambda x: (
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
gr.Tabs(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
with gr.Blocks(
|
||||
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
|
||||
) as sd_web:
|
||||
@@ -241,6 +246,7 @@ if __name__ == "__main__":
|
||||
inpaint_status,
|
||||
outpaint_status,
|
||||
upscaler_status,
|
||||
txt2img_sdxl_status,
|
||||
]
|
||||
)
|
||||
# with gr.TabItem(label="Model Manager", id=6):
|
||||
@@ -249,17 +255,17 @@ if __name__ == "__main__":
|
||||
# lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot", id=8):
|
||||
stablelm_chat.render()
|
||||
# with gr.TabItem(
|
||||
# label="Generate Sharding Config (Experimental)", id=9
|
||||
# ):
|
||||
# model_config_web.render()
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
minigpt4_web.render()
|
||||
# with gr.TabItem(
|
||||
# label="Generate Sharding Config (Experimental)", id=9
|
||||
# ):
|
||||
# model_config_web.render()
|
||||
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
# minigpt4_web.render()
|
||||
# with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
# h2ogpt_upload.render()
|
||||
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
|
||||
# h2ogpt_web.render()
|
||||
with gr.TabItem(label="Text-to-Image-SDXL (Experimental)", id=13):
|
||||
with gr.TabItem(label="Text-to-Image (SDXL)", id=13):
|
||||
txt2img_sdxl_web.render()
|
||||
|
||||
actual_port = app.usable_port()
|
||||
@@ -399,6 +405,12 @@ if __name__ == "__main__":
|
||||
[outputgallery_filename],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_txt2img_sdxl,
|
||||
0,
|
||||
[outputgallery_filename],
|
||||
[txt2img_sdxl_png_info_img, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_txt2img,
|
||||
0,
|
||||
|
||||
@@ -16,6 +16,11 @@ from apps.stable_diffusion.web.ui.txt2img_sdxl_ui import (
|
||||
txt2img_sdxl_custom_model,
|
||||
txt2img_sdxl_gallery,
|
||||
txt2img_sdxl_status,
|
||||
txt2img_sdxl_png_info_img,
|
||||
txt2img_sdxl_sendto_img2img,
|
||||
txt2img_sdxl_sendto_inpaint,
|
||||
txt2img_sdxl_sendto_outpaint,
|
||||
txt2img_sdxl_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_inf,
|
||||
@@ -83,6 +88,7 @@ from apps.stable_diffusion.web.ui.outputgallery_ui import (
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_txt2img_sdxl,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
|
||||
@@ -129,10 +129,6 @@ body {
|
||||
padding: 0 var(--size-4) !important;
|
||||
}
|
||||
|
||||
#ui_title {
|
||||
padding: var(--size-2) 0 0 var(--size-1);
|
||||
}
|
||||
|
||||
#top_logo {
|
||||
color: transparent;
|
||||
background-color: transparent;
|
||||
@@ -140,6 +136,10 @@ body {
|
||||
border: 0;
|
||||
}
|
||||
|
||||
#ui_title {
|
||||
padding: var(--size-2) 0 0 var(--size-1);
|
||||
}
|
||||
|
||||
#demo_title_outer {
|
||||
border-radius: 0;
|
||||
}
|
||||
@@ -234,11 +234,6 @@ footer {
|
||||
display:none;
|
||||
}
|
||||
|
||||
/* Hide the download icon from the nod logo */
|
||||
#top_logo button {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* workarounds for container=false not currently working for dropdowns */
|
||||
.dropdown_no_container {
|
||||
padding: 0 !important;
|
||||
|
||||
@@ -6,6 +6,12 @@ import PIL
|
||||
from math import ceil
|
||||
from PIL import Image
|
||||
|
||||
from gradio.components.image_editor import (
|
||||
Brush,
|
||||
Eraser,
|
||||
EditorData,
|
||||
EditorValue,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -34,6 +40,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
from apps.stable_diffusion.src.utils.stencils import (
|
||||
CannyDetector,
|
||||
OpenposeDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
import numpy as np
|
||||
@@ -74,6 +81,7 @@ def img2img_inf(
|
||||
control_mode: str,
|
||||
stencils: list,
|
||||
images: list,
|
||||
preprocessed_hints: list,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -97,18 +105,23 @@ def img2img_inf(
|
||||
|
||||
for i, stencil in enumerate(stencils):
|
||||
if images[i] is None and stencil is not None:
|
||||
return None, "A stencil must have an Image input"
|
||||
continue
|
||||
if images[i] is not None:
|
||||
if isinstance(images[i], dict):
|
||||
images[i] = images[i]["composite"]
|
||||
images[i] = images[i].convert("RGB")
|
||||
|
||||
if image_dict is None:
|
||||
if image_dict is None and images[0] is None:
|
||||
return None, "An Initial Image is required"
|
||||
# if use_stencil == "scribble":
|
||||
# image = image_dict["mask"].convert("RGB")
|
||||
if isinstance(image_dict, PIL.Image.Image):
|
||||
image = image_dict.convert("RGB")
|
||||
else:
|
||||
elif image_dict:
|
||||
image = image_dict["image"].convert("RGB")
|
||||
else:
|
||||
# TODO: enable t2i + controlnets
|
||||
image = None
|
||||
if image:
|
||||
image, _, _ = resize_stencil(image, width, height)
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
@@ -140,9 +153,7 @@ def img2img_inf(
|
||||
if stencil is not None:
|
||||
stencil_count += 1
|
||||
if stencil_count > 0:
|
||||
args.scheduler = "DDIM"
|
||||
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
|
||||
# image, width, height = resize_stencil(image)
|
||||
elif "Shark" in args.scheduler:
|
||||
print(
|
||||
f"Shark schedulers are not supported. Switching to EulerDiscrete "
|
||||
@@ -152,6 +163,7 @@ def img2img_inf(
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
args.precision = precision
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
print(stencils)
|
||||
new_config_obj = Config(
|
||||
"img2img",
|
||||
args.hf_model_id,
|
||||
@@ -170,7 +182,12 @@ def img2img_inf(
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
or any(
|
||||
global_obj.get_cfg_obj().stencils[idx] != stencil
|
||||
for idx, stencil in enumerate(stencils)
|
||||
)
|
||||
):
|
||||
print("clearing config because you changed something important")
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.batch_count = batch_count
|
||||
@@ -186,7 +203,7 @@ def img2img_inf(
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
else "runwayml/stable-diffusion-v1-5"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(args.scheduler)
|
||||
@@ -269,6 +286,7 @@ def img2img_inf(
|
||||
images,
|
||||
resample_type=resample_type,
|
||||
control_mode=control_mode,
|
||||
preprocessed_hints=preprocessed_hints,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(
|
||||
@@ -299,6 +317,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
STENCIL_COUNT = 2
|
||||
stencils = gr.State([None] * STENCIL_COUNT)
|
||||
images = gr.State([None] * STENCIL_COUNT)
|
||||
preprocessed_hints = gr.State([None] * STENCIL_COUNT)
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
@@ -310,6 +329,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
show_download_button=False,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
@@ -363,110 +383,421 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
# TODO: make this import image prompt info if it exists
|
||||
img2img_init_image = gr.Image(
|
||||
label="Input Image",
|
||||
source="upload",
|
||||
tool="sketch",
|
||||
type="pil",
|
||||
height=300,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Multistencil Options", open=False):
|
||||
choices = ["None", "canny", "openpose", "scribble"]
|
||||
choices = [
|
||||
"None",
|
||||
"canny",
|
||||
"openpose",
|
||||
"scribble",
|
||||
"zoedepth",
|
||||
]
|
||||
|
||||
def cnet_preview(
|
||||
checked, model, input_image, index, stencils, images
|
||||
model,
|
||||
input_image,
|
||||
index,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
):
|
||||
if not checked:
|
||||
stencils[index] = None
|
||||
images[index] = None
|
||||
return (None, stencils, images)
|
||||
if isinstance(input_image, PIL.Image.Image):
|
||||
img_dict = {
|
||||
"background": None,
|
||||
"layers": [None],
|
||||
"composite": input_image,
|
||||
}
|
||||
input_image = EditorValue(img_dict)
|
||||
images[index] = input_image
|
||||
stencils[index] = model
|
||||
if model:
|
||||
stencils[index] = model
|
||||
match model:
|
||||
case "canny":
|
||||
canny = CannyDetector()
|
||||
result = canny(np.array(input_image), 100, 200)
|
||||
result = canny(
|
||||
np.array(input_image["composite"]),
|
||||
100,
|
||||
200,
|
||||
)
|
||||
preprocessed_hints[index] = Image.fromarray(
|
||||
result
|
||||
)
|
||||
return (
|
||||
[Image.fromarray(result), result],
|
||||
Image.fromarray(result),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
case "openpose":
|
||||
openpose = OpenposeDetector()
|
||||
result = openpose(np.array(input_image))
|
||||
# TODO: This is just an empty canvas, need to draw the candidates (which are in result[1])
|
||||
result = openpose(
|
||||
np.array(input_image["composite"])
|
||||
)
|
||||
preprocessed_hints[index] = Image.fromarray(
|
||||
result[0]
|
||||
)
|
||||
return (
|
||||
[Image.fromarray(result[0]), result],
|
||||
Image.fromarray(result[0]),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
case "zoedepth":
|
||||
zoedepth = ZoeDetector()
|
||||
result = zoedepth(
|
||||
np.array(input_image["composite"])
|
||||
)
|
||||
preprocessed_hints[index] = Image.fromarray(
|
||||
result
|
||||
)
|
||||
return (
|
||||
Image.fromarray(result),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
case "scribble":
|
||||
preprocessed_hints[index] = input_image[
|
||||
"composite"
|
||||
]
|
||||
return (
|
||||
input_image["composite"],
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
case _:
|
||||
return (None, stencils, images)
|
||||
preprocessed_hints[index] = None
|
||||
return (
|
||||
None,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
|
||||
def import_original(original_img, width, height):
|
||||
resized_img, _, _ = resize_stencil(
|
||||
original_img, width, height
|
||||
)
|
||||
img_dict = {
|
||||
"background": resized_img,
|
||||
"layers": [resized_img],
|
||||
"composite": None,
|
||||
}
|
||||
return gr.ImageEditor(
|
||||
value=EditorValue(img_dict),
|
||||
crop_size=(width, height),
|
||||
)
|
||||
|
||||
def create_canvas(width, height):
|
||||
data = Image.fromarray(
|
||||
np.zeros(
|
||||
shape=(height, width, 3),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
+ 255
|
||||
)
|
||||
img_dict = {
|
||||
"background": data,
|
||||
"layers": [data],
|
||||
"composite": None,
|
||||
}
|
||||
return EditorValue(img_dict)
|
||||
|
||||
def update_cn_input(
|
||||
model,
|
||||
width,
|
||||
height,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
index,
|
||||
):
|
||||
if model == None:
|
||||
stencils[index] = None
|
||||
images[index] = None
|
||||
preprocessed_hints[index] = None
|
||||
return [
|
||||
gr.ImageEditor(value=None, visible=False),
|
||||
gr.Image(value=None),
|
||||
gr.Slider(visible=False),
|
||||
gr.Slider(visible=False),
|
||||
gr.Button(visible=False),
|
||||
gr.Button(visible=False),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
]
|
||||
elif model == "scribble":
|
||||
return [
|
||||
gr.ImageEditor(
|
||||
visible=True,
|
||||
interactive=True,
|
||||
show_label=False,
|
||||
image_mode="RGB",
|
||||
type="pil",
|
||||
brush=Brush(
|
||||
colors=["#000000"],
|
||||
color_mode="fixed",
|
||||
default_size=2,
|
||||
),
|
||||
),
|
||||
gr.Image(
|
||||
visible=True,
|
||||
show_label=False,
|
||||
interactive=True,
|
||||
show_download_button=False,
|
||||
),
|
||||
gr.Slider(visible=True, label="Canvas Width"),
|
||||
gr.Slider(visible=True, label="Canvas Height"),
|
||||
gr.Button(visible=True),
|
||||
gr.Button(visible=False),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
]
|
||||
else:
|
||||
return [
|
||||
gr.ImageEditor(
|
||||
visible=True,
|
||||
image_mode="RGB",
|
||||
type="pil",
|
||||
interactive=True,
|
||||
),
|
||||
gr.Image(
|
||||
visible=True,
|
||||
show_label=False,
|
||||
interactive=True,
|
||||
show_download_button=False,
|
||||
),
|
||||
gr.Slider(visible=True, label="Input Width"),
|
||||
gr.Slider(visible=True, label="Input Height"),
|
||||
gr.Button(visible=False),
|
||||
gr.Button(visible=True),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
]
|
||||
|
||||
with gr.Row():
|
||||
cnet_1 = gr.Checkbox(show_label=False)
|
||||
cnet_1_model = gr.Dropdown(
|
||||
label="Controlnet 1",
|
||||
value="None",
|
||||
choices=choices,
|
||||
)
|
||||
cnet_1_image = gr.Image(
|
||||
source="upload",
|
||||
tool=None,
|
||||
with gr.Column():
|
||||
cnet_1 = gr.Button(
|
||||
value="Generate controlnet input"
|
||||
)
|
||||
cnet_1_model = gr.Dropdown(
|
||||
label="Controlnet 1",
|
||||
value="None",
|
||||
choices=choices,
|
||||
)
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
make_canvas = gr.Button(
|
||||
value="Make Canvas!",
|
||||
visible=False,
|
||||
)
|
||||
use_input_img_1 = gr.Button(
|
||||
value="Use Original Image",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
cnet_1_image = gr.ImageEditor(
|
||||
visible=False,
|
||||
image_mode="RGB",
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
label="Input Image",
|
||||
type="pil",
|
||||
)
|
||||
cnet_1_output = gr.Gallery(
|
||||
show_label=False,
|
||||
object_fit="scale-down",
|
||||
rows=1,
|
||||
columns=1,
|
||||
cnet_1_output = gr.Image(
|
||||
value=None,
|
||||
visible=True,
|
||||
label="Preprocessed Hint",
|
||||
interactive=True,
|
||||
)
|
||||
cnet_1.change(
|
||||
|
||||
use_input_img_1.click(
|
||||
import_original,
|
||||
[img2img_init_image, canvas_width, canvas_height],
|
||||
[cnet_1_image],
|
||||
)
|
||||
|
||||
cnet_1_model.change(
|
||||
fn=(
|
||||
lambda a, b, c, s, i: cnet_preview(
|
||||
a, b, c, 0, s, i
|
||||
lambda m, w, h, s, i, p: update_cn_input(
|
||||
m, w, h, s, i, p, 0
|
||||
)
|
||||
),
|
||||
inputs=[
|
||||
cnet_1_model,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[
|
||||
cnet_1_image,
|
||||
cnet_1_output,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
make_canvas,
|
||||
use_input_img_1,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
)
|
||||
make_canvas.click(
|
||||
create_canvas,
|
||||
[canvas_width, canvas_height],
|
||||
[
|
||||
cnet_1_image,
|
||||
],
|
||||
)
|
||||
gr.on(
|
||||
triggers=[cnet_1.click],
|
||||
fn=(
|
||||
lambda a, b, s, i, p: cnet_preview(
|
||||
a, b, 0, s, i, p
|
||||
)
|
||||
),
|
||||
inputs=[
|
||||
cnet_1,
|
||||
cnet_1_model,
|
||||
cnet_1_image,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[
|
||||
cnet_1_output,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[cnet_1_output, stencils, images],
|
||||
)
|
||||
with gr.Row():
|
||||
cnet_2 = gr.Checkbox(show_label=False)
|
||||
cnet_2_model = gr.Dropdown(
|
||||
label="Controlnet 2",
|
||||
value="None",
|
||||
choices=choices,
|
||||
)
|
||||
cnet_2_image = gr.Image(
|
||||
source="upload",
|
||||
tool=None,
|
||||
with gr.Column():
|
||||
cnet_2 = gr.Button(
|
||||
value="Generate controlnet input"
|
||||
)
|
||||
cnet_2_model = gr.Dropdown(
|
||||
label="Controlnet 2",
|
||||
value="None",
|
||||
choices=choices,
|
||||
)
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
make_canvas = gr.Button(
|
||||
value="Make Canvas!",
|
||||
visible=False,
|
||||
)
|
||||
use_input_img_2 = gr.Button(
|
||||
value="Use Original Image",
|
||||
visible=False,
|
||||
)
|
||||
cnet_2_image = gr.ImageEditor(
|
||||
visible=False,
|
||||
image_mode="RGB",
|
||||
interactive=True,
|
||||
type="pil",
|
||||
show_label=True,
|
||||
label="Input Image",
|
||||
)
|
||||
cnet_2_output = gr.Gallery(
|
||||
show_label=False,
|
||||
object_fit="scale-down",
|
||||
rows=1,
|
||||
columns=1,
|
||||
use_input_img_2.click(
|
||||
import_original,
|
||||
[img2img_init_image, canvas_width, canvas_height],
|
||||
[cnet_2_image],
|
||||
)
|
||||
cnet_2.change(
|
||||
cnet_2_output = gr.Image(
|
||||
value=None,
|
||||
visible=True,
|
||||
label="Preprocessed Hint",
|
||||
interactive=True,
|
||||
)
|
||||
cnet_2_model.change(
|
||||
fn=(
|
||||
lambda a, b, c, s, i: cnet_preview(
|
||||
a, b, c, 1, s, i
|
||||
lambda m, w, h, s, i, p: update_cn_input(
|
||||
m, w, h, s, i, p, 0
|
||||
)
|
||||
),
|
||||
inputs=[
|
||||
cnet_2_model,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[
|
||||
cnet_2_image,
|
||||
cnet_2_output,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
make_canvas,
|
||||
use_input_img_2,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
)
|
||||
make_canvas.click(
|
||||
create_canvas,
|
||||
[canvas_width, canvas_height],
|
||||
[
|
||||
cnet_2_image,
|
||||
],
|
||||
)
|
||||
cnet_2.click(
|
||||
fn=(
|
||||
lambda a, b, s, i, p: cnet_preview(
|
||||
a, b, 1, s, i, p
|
||||
)
|
||||
),
|
||||
inputs=[
|
||||
cnet_2,
|
||||
cnet_2_model,
|
||||
cnet_2_image,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[
|
||||
cnet_2_output,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[cnet_2_output, stencils, images],
|
||||
)
|
||||
control_mode = gr.Radio(
|
||||
choices=["Prompt", "Balanced", "Controlnet"],
|
||||
@@ -688,6 +1019,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
control_mode,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[
|
||||
img2img_gallery,
|
||||
|
||||
@@ -234,6 +234,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
show_download_button=False,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
@@ -290,8 +291,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
|
||||
inpaint_init_image = gr.Image(
|
||||
label="Masked Image",
|
||||
source="upload",
|
||||
tool="sketch",
|
||||
sources="upload",
|
||||
type="pil",
|
||||
height=350,
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
show_download_button=False,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
|
||||
@@ -104,7 +104,6 @@ with gr.Blocks() as model_web:
|
||||
civit_models = gr.Gallery(
|
||||
label="Civitai Model Gallery",
|
||||
value=None,
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -239,6 +239,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
show_download_button=False,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
|
||||
@@ -84,6 +84,7 @@ with gr.Blocks() as outputgallery_web:
|
||||
show_label=True,
|
||||
elem_id="top_logo",
|
||||
elem_classes="logo_centered",
|
||||
show_download_button=False,
|
||||
)
|
||||
|
||||
gallery = gr.Gallery(
|
||||
@@ -95,7 +96,7 @@ with gr.Blocks() as outputgallery_web:
|
||||
)
|
||||
|
||||
with gr.Column(scale=4):
|
||||
with gr.Box():
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
with gr.Column(
|
||||
scale=15,
|
||||
@@ -152,6 +153,7 @@ with gr.Blocks() as outputgallery_web:
|
||||
wrap=True,
|
||||
elem_classes="output_parameters_dataframe",
|
||||
value=[["Status", "No image selected"]],
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Send To", open=True):
|
||||
@@ -162,6 +164,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
elem_classes="outputgallery_sendto",
|
||||
size="sm",
|
||||
)
|
||||
outputgallery_sendto_txt2img_sdxl = gr.Button(
|
||||
value="Txt2Img XL",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
size="sm",
|
||||
)
|
||||
|
||||
outputgallery_sendto_img2img = gr.Button(
|
||||
value="Img2Img",
|
||||
@@ -195,17 +203,17 @@ with gr.Blocks() as outputgallery_web:
|
||||
|
||||
def on_clear_gallery():
|
||||
return [
|
||||
gr.Gallery.update(
|
||||
gr.Gallery(
|
||||
value=[],
|
||||
visible=False,
|
||||
),
|
||||
gr.Image.update(
|
||||
gr.Image(
|
||||
visible=True,
|
||||
),
|
||||
]
|
||||
|
||||
def on_image_columns_change(columns):
|
||||
return gr.Gallery.update(columns=columns)
|
||||
return gr.Gallery(columns=columns)
|
||||
|
||||
def on_select_subdir(subdir) -> list:
|
||||
# evt.value is the subdirectory name
|
||||
@@ -215,12 +223,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
)
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
gr.Gallery(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
gr.Image(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
@@ -254,16 +262,16 @@ with gr.Blocks() as outputgallery_web:
|
||||
)
|
||||
|
||||
return [
|
||||
gr.Dropdown.update(
|
||||
gr.Dropdown(
|
||||
choices=refreshed_subdirs,
|
||||
value=new_subdir,
|
||||
),
|
||||
refreshed_subdirs,
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
gr.Gallery(
|
||||
value=new_images, label=new_label, visible=len(new_images) > 0
|
||||
),
|
||||
gr.Image.update(
|
||||
gr.Image(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
@@ -289,12 +297,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
gr.Gallery(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
gr.Image(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
@@ -332,12 +340,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
return [
|
||||
# disable or enable each of the sendto button based on whether
|
||||
# an image is selected
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
]
|
||||
|
||||
# The time first our tab is selected we need to do an initial refresh
|
||||
@@ -414,6 +422,7 @@ with gr.Blocks() as outputgallery_web:
|
||||
[outputgallery_filename],
|
||||
[
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_txt2img_sdxl,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
|
||||
@@ -141,7 +141,9 @@ def chat(
|
||||
prompt_prefix,
|
||||
history,
|
||||
model,
|
||||
device,
|
||||
backend,
|
||||
devices,
|
||||
sharded,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
@@ -153,7 +155,8 @@ def chat(
|
||||
global vicuna_model
|
||||
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
device, device_id = clean_device_info(device)
|
||||
device, device_id = clean_device_info(devices[0])
|
||||
no_of_devices = len(devices)
|
||||
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
@@ -173,6 +176,13 @@ def chat(
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
|
||||
_extra_args = _extra_args + [
|
||||
"--iree-global-opt-enable-quantized-matmul-reassociation",
|
||||
"--iree-llvmcpu-enable-quantized-matmul-reassociation",
|
||||
"--iree-opt-const-eval=false",
|
||||
"--iree-opt-data-tiling=false",
|
||||
]
|
||||
|
||||
if device == "vulkan":
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if vulkan_target_triple == "":
|
||||
@@ -214,7 +224,7 @@ def chat(
|
||||
)
|
||||
print(f"extra args = {_extra_args}")
|
||||
|
||||
if model_name == "vicuna4":
|
||||
if sharded:
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
@@ -223,6 +233,7 @@ def chat(
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
n_devices=no_of_devices,
|
||||
)
|
||||
else:
|
||||
# if config_file is None:
|
||||
@@ -250,13 +261,14 @@ def chat(
|
||||
total_time_ms = 0.001 # In order to avoid divide by zero error
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
for text, msg, exec_time in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
# for text, msg, exec_time in progress.tqdm(
|
||||
# vicuna_model.generate(prompt, cli=cli),
|
||||
# desc="generating response",
|
||||
# ):
|
||||
for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli):
|
||||
if msg is None:
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
prefill_time = exec_time / 1000
|
||||
is_first = False
|
||||
else:
|
||||
total_time_ms += exec_time
|
||||
@@ -377,6 +389,16 @@ def view_json_file(file_obj):
|
||||
return content
|
||||
|
||||
|
||||
filtered_devices = dict()
|
||||
|
||||
|
||||
def change_backend(backend):
|
||||
new_choices = gr.Dropdown(
|
||||
choices=filtered_devices[backend], label=f"{backend} devices"
|
||||
)
|
||||
return new_choices
|
||||
|
||||
|
||||
with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
with gr.Row():
|
||||
model_choices = list(
|
||||
@@ -393,15 +415,22 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
# show cpu-task device first in list for chatbot
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
backend_list = ["cpu", "cuda", "vulkan", "rocm"]
|
||||
for x in backend_list:
|
||||
filtered_devices[x] = [y for y in supported_devices if x in y]
|
||||
print(filtered_devices)
|
||||
|
||||
backend = gr.Radio(
|
||||
label="backend",
|
||||
value="cpu",
|
||||
choices=backend_list,
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
label="cpu devices",
|
||||
choices=filtered_devices["cpu"],
|
||||
interactive=True,
|
||||
allow_custom_value=True,
|
||||
# multiselect=True,
|
||||
multiselect=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
@@ -425,14 +454,19 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
sharded = gr.Checkbox(
|
||||
label="Shard Model",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
config_file = gr.File(
|
||||
label="Upload sharding configuration", visible=False
|
||||
)
|
||||
json_view_button = gr.Button(label="View as JSON", visible=False)
|
||||
json_view = gr.JSON(interactive=True, visible=False)
|
||||
json_view_button = gr.Button(value="View as JSON", visible=False)
|
||||
json_view = gr.JSON(visible=False)
|
||||
json_view_button.click(
|
||||
fn=view_json_file, inputs=[config_file], outputs=[json_view]
|
||||
)
|
||||
@@ -452,6 +486,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
|
||||
backend.change(
|
||||
fn=change_backend,
|
||||
inputs=[backend],
|
||||
outputs=[device],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
@@ -464,7 +505,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
backend,
|
||||
device,
|
||||
sharded,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
@@ -485,7 +528,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
backend,
|
||||
device,
|
||||
sharded,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
|
||||
@@ -11,9 +11,11 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_models,
|
||||
predefined_sdxl_models,
|
||||
cancel_sd,
|
||||
set_model_default_configs,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
@@ -50,17 +52,17 @@ def txt2img_sdxl_inf(
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
if precision != "fp16":
|
||||
print("currently we support fp16 for SDXL")
|
||||
precision = "fp16"
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
@@ -71,6 +73,10 @@ def txt2img_sdxl_inf(
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
if precision != "fp16":
|
||||
print("currently we support fp16 for SDXL")
|
||||
precision = "fp16"
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
@@ -93,13 +99,15 @@ def txt2img_sdxl_inf(
|
||||
else:
|
||||
args.hf_model_id = model_id
|
||||
|
||||
# if custom_vae != "None":
|
||||
# args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
if custom_vae:
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = ""
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
@@ -115,7 +123,7 @@ def txt2img_sdxl_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
stencils=None,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
@@ -144,31 +152,29 @@ def txt2img_sdxl_inf(
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
# For SDXL we set max_length as 77.
|
||||
print("Setting max_length = 77")
|
||||
max_length = 77
|
||||
if global_obj.get_cfg_obj().ondemand:
|
||||
print("Running txt2img in memory efficient mode.")
|
||||
txt2img_sdxl_obj = Text2ImageSDXLPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=precision,
|
||||
max_length=max_length,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=global_obj.get_cfg_obj().ondemand,
|
||||
global_obj.set_sd_obj(
|
||||
Text2ImageSDXLPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=precision,
|
||||
max_length=max_length,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=global_obj.get_cfg_obj().ondemand,
|
||||
)
|
||||
)
|
||||
global_obj.set_sd_obj(txt2img_sdxl_obj)
|
||||
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
@@ -220,7 +226,12 @@ def txt2img_sdxl_inf(
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
|
||||
theme = gr.themes.Glass(
|
||||
primary_hue="slate",
|
||||
secondary_hue="gray",
|
||||
)
|
||||
|
||||
with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
@@ -232,6 +243,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
show_download_button=False,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
@@ -239,7 +251,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
t2i_sdxl_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
txt2img_sdxl_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
@@ -247,36 +259,104 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
choices=[
|
||||
"stabilityai/stable-diffusion-xl-base-1.0"
|
||||
],
|
||||
choices=predefined_sdxl_models
|
||||
+ get_custom_model_files(
|
||||
custom_checkpoint_type="sdxl"
|
||||
),
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
t2i_sdxl_vae_info = (
|
||||
str(get_custom_model_path("vae"))
|
||||
).replace("\\", "\n\\")
|
||||
t2i_sdxl_vae_info = (
|
||||
f"VAE Path: {t2i_sdxl_vae_info}"
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"VAE Models",
|
||||
info=t2i_sdxl_vae_info,
|
||||
elem_id="custom_model",
|
||||
value="None",
|
||||
choices=[
|
||||
None,
|
||||
"madebyollin/sdxl-vae-fp16-fix",
|
||||
]
|
||||
+ get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Column(scale=1, min_width=170):
|
||||
txt2img_sdxl_png_info_img = gr.Image(
|
||||
label="Import PNG info",
|
||||
elem_id="txt2img_prompt_image",
|
||||
type="pil",
|
||||
visible=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
txt2img_sdxl_autogen = gr.Checkbox(
|
||||
label="Auto-Generate Images",
|
||||
value=False,
|
||||
visible=False,
|
||||
)
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
show_copy_button=True,
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
show_copy_button=True,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
t2i_sdxl_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
t2i_sdxl_lora_info = f"LoRA Path: {t2i_sdxl_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=t2i_sdxl_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value="DDIM",
|
||||
choices=["DDIM"],
|
||||
value="EulerDiscrete",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"EulerAncestralDiscrete",
|
||||
"EulerDiscrete",
|
||||
"LCMScheduler",
|
||||
],
|
||||
allow_custom_value=True,
|
||||
visible=False,
|
||||
visible=True,
|
||||
)
|
||||
with gr.Column():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -291,31 +371,34 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
512,
|
||||
1024,
|
||||
value=1024,
|
||||
step=8,
|
||||
step=256,
|
||||
label="Height",
|
||||
visible=False,
|
||||
visible=True,
|
||||
interactive=True,
|
||||
)
|
||||
width = gr.Slider(
|
||||
512,
|
||||
1024,
|
||||
value=1024,
|
||||
step=8,
|
||||
step=256,
|
||||
label="Width",
|
||||
visible=False,
|
||||
visible=True,
|
||||
interactive=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp16",
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
value=77,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
@@ -333,7 +416,7 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
label="Guidance Scale",
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
@@ -357,12 +440,14 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
@@ -391,10 +476,10 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
columns=[2],
|
||||
object_fit="contain",
|
||||
object_fit="scale_down",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"{t2i_model_info}\n"
|
||||
value=f"{t2i_sdxl_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
@@ -413,7 +498,22 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
txt2img_sdxl_sendto_img2img = gr.Button(
|
||||
value="Send To Img2Img",
|
||||
visible=False,
|
||||
)
|
||||
txt2img_sdxl_sendto_inpaint = gr.Button(
|
||||
value="Send To Inpaint",
|
||||
visible=False,
|
||||
)
|
||||
txt2img_sdxl_sendto_outpaint = gr.Button(
|
||||
value="Send To Outpaint",
|
||||
visible=False,
|
||||
)
|
||||
txt2img_sdxl_sendto_upscaler = gr.Button(
|
||||
value="Send To Upscaler",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=txt2img_sdxl_inf,
|
||||
@@ -429,24 +529,55 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
txt2img_sdxl_custom_model,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
queue=True,
|
||||
)
|
||||
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=txt2img_sdxl_status,
|
||||
concurrency_limit=1,
|
||||
)
|
||||
|
||||
def autogen_changed(checked):
|
||||
if checked:
|
||||
args.autogen = True
|
||||
else:
|
||||
args.autogen = False
|
||||
|
||||
def check_last_input(prompt):
|
||||
if not prompt.endswith(" "):
|
||||
return True
|
||||
elif not args.autogen:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
auto_gen_kwargs = dict(
|
||||
fn=check_last_input,
|
||||
inputs=[negative_prompt],
|
||||
outputs=[txt2img_sdxl_status],
|
||||
concurrency_limit=1,
|
||||
)
|
||||
|
||||
txt2img_sdxl_autogen.change(
|
||||
fn=autogen_changed,
|
||||
inputs=[txt2img_sdxl_autogen],
|
||||
outputs=None,
|
||||
)
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
@@ -454,5 +585,66 @@ with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
cancels=[
|
||||
prompt_submit,
|
||||
neg_prompt_submit,
|
||||
generate_click,
|
||||
],
|
||||
)
|
||||
|
||||
txt2img_sdxl_png_info_img.change(
|
||||
fn=import_png_metadata,
|
||||
inputs=[
|
||||
txt2img_sdxl_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
scheduler,
|
||||
guidance_scale,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
txt2img_sdxl_custom_model,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
outputs=[
|
||||
txt2img_sdxl_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
scheduler,
|
||||
guidance_scale,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
txt2img_sdxl_custom_model,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
)
|
||||
txt2img_sdxl_custom_model.change(
|
||||
fn=set_model_default_configs,
|
||||
inputs=[
|
||||
txt2img_sdxl_custom_model,
|
||||
],
|
||||
outputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
scheduler,
|
||||
guidance_scale,
|
||||
width,
|
||||
height,
|
||||
custom_vae,
|
||||
txt2img_sdxl_autogen,
|
||||
],
|
||||
)
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
@@ -35,6 +37,34 @@ from apps.stable_diffusion.src.utils import (
|
||||
resampler_list,
|
||||
)
|
||||
|
||||
# Names of all interactive fields that can be edited by user
|
||||
all_gradio_labels = [
|
||||
"txt2img_custom_model",
|
||||
"custom_vae",
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
"lora_weights",
|
||||
"lora_hf_id",
|
||||
"scheduler",
|
||||
"save_metadata_to_png",
|
||||
"save_metadata_to_json",
|
||||
"height",
|
||||
"width",
|
||||
"steps",
|
||||
"guidance_scale",
|
||||
"Low VRAM",
|
||||
"use_hiresfix",
|
||||
"resample_type",
|
||||
"hiresfix_height",
|
||||
"hiresfix_width",
|
||||
"hiresfix_strength",
|
||||
"batch_count",
|
||||
"batch_size",
|
||||
"repeatable_seeds",
|
||||
"seed",
|
||||
"device",
|
||||
]
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_iree_metal_target_platform = args.iree_metal_target_platform
|
||||
@@ -281,6 +311,7 @@ def txt2img_inf(
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
stencils=[],
|
||||
control_mode=None,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
@@ -302,7 +333,92 @@ def txt2img_inf(
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
|
||||
# This function export values for all fields that can be edited by user to the settings.json file in ui folder
|
||||
def export_settings(*values):
|
||||
settings_list = list(zip(all_gradio_labels, values))
|
||||
settings = {}
|
||||
|
||||
for label, value in settings_list:
|
||||
settings[label] = value
|
||||
|
||||
settings = {"txt2img": settings}
|
||||
with open("./ui/settings.json", "w") as json_file:
|
||||
json.dump(settings, json_file, indent=4)
|
||||
|
||||
|
||||
# This function loads all values for all fields that can be edited by user from the settings.json file in ui folder
|
||||
def load_settings():
|
||||
try:
|
||||
with open("./ui/settings.json", "r") as json_file:
|
||||
loaded_settings = json.load(json_file)["txt2img"]
|
||||
except (FileNotFoundError, KeyError):
|
||||
warnings.warn(
|
||||
"Settings.json file not found or 'txt2img' key is missing. Using default values for fields."
|
||||
)
|
||||
loaded_settings = (
|
||||
{}
|
||||
) # json file not existing or the data wasn't saved yet
|
||||
|
||||
return [
|
||||
loaded_settings.get(
|
||||
"txt2img_custom_model",
|
||||
os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
),
|
||||
loaded_settings.get(
|
||||
"custom_vae",
|
||||
os.path.basename(args.custom_vae) if args.custom_vae else "None",
|
||||
),
|
||||
loaded_settings.get("prompt", args.prompts[0]),
|
||||
loaded_settings.get("negative_prompt", args.negative_prompts[0]),
|
||||
loaded_settings.get("lora_weights", "None"),
|
||||
loaded_settings.get("lora_hf_id", ""),
|
||||
loaded_settings.get("scheduler", args.scheduler),
|
||||
loaded_settings.get(
|
||||
"save_metadata_to_png", args.write_metadata_to_png
|
||||
),
|
||||
loaded_settings.get(
|
||||
"save_metadata_to_json", args.save_metadata_to_json
|
||||
),
|
||||
loaded_settings.get("height", args.height),
|
||||
loaded_settings.get("width", args.width),
|
||||
loaded_settings.get("steps", args.steps),
|
||||
loaded_settings.get("guidance_scale", args.guidance_scale),
|
||||
loaded_settings.get("Low VRAM", args.ondemand),
|
||||
loaded_settings.get("use_hiresfix", args.use_hiresfix),
|
||||
loaded_settings.get("resample_type", args.resample_type),
|
||||
loaded_settings.get("hiresfix_height", args.hiresfix_height),
|
||||
loaded_settings.get("hiresfix_width", args.hiresfix_width),
|
||||
loaded_settings.get("hiresfix_strength", args.hiresfix_strength),
|
||||
loaded_settings.get("batch_count", args.batch_count),
|
||||
loaded_settings.get("batch_size", args.batch_size),
|
||||
loaded_settings.get("repeatable_seeds", args.repeatable_seeds),
|
||||
loaded_settings.get("seed", args.seed),
|
||||
loaded_settings.get("device", available_devices[0]),
|
||||
]
|
||||
|
||||
|
||||
# This function loads the user's exported default settings on the start of program
|
||||
def onload_load_settings():
|
||||
loaded_data = load_settings()
|
||||
structured_data = settings_list = list(zip(all_gradio_labels, loaded_data))
|
||||
return dict(structured_data)
|
||||
|
||||
|
||||
default_settings = onload_load_settings()
|
||||
with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
@@ -314,6 +430,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
show_download_button=False,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
@@ -326,9 +443,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label=f"Models",
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
value=default_settings.get(
|
||||
"txt2img_custom_model"
|
||||
),
|
||||
choices=get_custom_model_files()
|
||||
+ predefined_models,
|
||||
allow_custom_value=True,
|
||||
@@ -343,9 +460,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label=f"VAE Models",
|
||||
info=t2i_vae_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.custom_vae)
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
value=default_settings.get("custom_vae"),
|
||||
choices=["None"]
|
||||
+ get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
@@ -356,20 +471,24 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Import PNG info",
|
||||
elem_id="txt2img_prompt_image",
|
||||
type="pil",
|
||||
tool="None",
|
||||
visible=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
value=default_settings.get("prompt"),
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
# TODO: coming soon
|
||||
autogen = gr.Checkbox(
|
||||
label="Continuous Generation",
|
||||
visible=False,
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
value=default_settings.get("negative_prompt"),
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
@@ -384,7 +503,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=t2i_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
value=default_settings.get("lora_weights"),
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
@@ -394,7 +513,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
value=default_settings.get("lora_hf_id"),
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
@@ -408,33 +527,37 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
value=default_settings.get("scheduler"),
|
||||
choices=scheduler_list,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
value=default_settings.get(
|
||||
"save_metadata_to_png"
|
||||
),
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
value=default_settings.get(
|
||||
"save_metadata_to_json"
|
||||
),
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.height,
|
||||
value=default_settings.get("height"),
|
||||
step=8,
|
||||
label="Height",
|
||||
)
|
||||
width = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.width,
|
||||
value=default_settings.get("width"),
|
||||
step=8,
|
||||
label="Width",
|
||||
)
|
||||
@@ -459,18 +582,22 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
1,
|
||||
100,
|
||||
value=default_settings.get("steps"),
|
||||
step=1,
|
||||
label="Steps",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
value=default_settings.get("guidance_scale"),
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
value=default_settings.get("Low VRAM"),
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
@@ -479,7 +606,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
value=default_settings.get("batch_count"),
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
@@ -490,23 +617,23 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
label=default_settings.get("batch_size"),
|
||||
interactive=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
default_settings.get("repeatable_seeds"),
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Accordion(label="Hires Fix Options", open=False):
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
use_hiresfix = gr.Checkbox(
|
||||
value=args.use_hiresfix,
|
||||
value=default_settings.get("use_hiresfix"),
|
||||
label="Use Hires Fix",
|
||||
interactive=True,
|
||||
)
|
||||
resample_type = gr.Dropdown(
|
||||
value=args.resample_type,
|
||||
value=default_settings.get("resample_type"),
|
||||
choices=resampler_list,
|
||||
label="Resample Type",
|
||||
allow_custom_value=False,
|
||||
@@ -514,34 +641,34 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
hiresfix_height = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.hiresfix_height,
|
||||
value=default_settings.get("hiresfix_height"),
|
||||
step=8,
|
||||
label="Hires Fix Height",
|
||||
)
|
||||
hiresfix_width = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.hiresfix_width,
|
||||
value=default_settings.get("hiresfix_width"),
|
||||
step=8,
|
||||
label="Hires Fix Width",
|
||||
)
|
||||
hiresfix_strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=args.hiresfix_strength,
|
||||
value=default_settings.get("hiresfix_strength"),
|
||||
step=0.01,
|
||||
label="Hires Fix Denoising Strength",
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
value=default_settings.get("seed"),
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
value=default_settings.get("device"),
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
@@ -592,6 +719,75 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
txt2img_sendto_upscaler = gr.Button(
|
||||
value="SendTo Upscaler"
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
export_defaults = gr.Button(
|
||||
value="Load Default Settings"
|
||||
)
|
||||
export_defaults.click(
|
||||
fn=load_settings,
|
||||
inputs=[],
|
||||
outputs=[
|
||||
txt2img_custom_model,
|
||||
custom_vae,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
scheduler,
|
||||
save_metadata_to_png,
|
||||
save_metadata_to_json,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
ondemand,
|
||||
use_hiresfix,
|
||||
resample_type,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
hiresfix_strength,
|
||||
batch_count,
|
||||
batch_size,
|
||||
repeatable_seeds,
|
||||
seed,
|
||||
device,
|
||||
],
|
||||
)
|
||||
with gr.Column(scale=2):
|
||||
export_defaults = gr.Button(
|
||||
value="Export Default Settings"
|
||||
)
|
||||
export_defaults.click(
|
||||
fn=export_settings,
|
||||
inputs=[
|
||||
txt2img_custom_model,
|
||||
custom_vae,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
scheduler,
|
||||
save_metadata_to_png,
|
||||
save_metadata_to_json,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
ondemand,
|
||||
use_hiresfix,
|
||||
resample_type,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
hiresfix_strength,
|
||||
batch_count,
|
||||
batch_size,
|
||||
repeatable_seeds,
|
||||
seed,
|
||||
device,
|
||||
],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=txt2img_inf,
|
||||
@@ -680,12 +876,12 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
# SharkEulerDiscrete doesn't work with img2img which hires_fix uses
|
||||
def set_compatible_schedulers(hires_fix_selected):
|
||||
if hires_fix_selected:
|
||||
return gr.Dropdown.update(
|
||||
return gr.Dropdown(
|
||||
choices=scheduler_list_cpu_only,
|
||||
value="DEISMultistep",
|
||||
)
|
||||
else:
|
||||
return gr.Dropdown.update(
|
||||
return gr.Dropdown(
|
||||
choices=scheduler_list,
|
||||
value="SharkEulerDiscrete",
|
||||
)
|
||||
|
||||
@@ -258,6 +258,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
show_download_button=False,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
|
||||
@@ -4,6 +4,7 @@ import glob
|
||||
import math
|
||||
import json
|
||||
import safetensors
|
||||
import gradio as gr
|
||||
|
||||
from pathlib import Path
|
||||
from apps.stable_diffusion.src import args
|
||||
@@ -64,9 +65,11 @@ scheduler_list_cpu_only = [
|
||||
"DPMSolverSinglestep",
|
||||
"DDPM",
|
||||
"HeunDiscrete",
|
||||
"LCMScheduler",
|
||||
]
|
||||
scheduler_list = scheduler_list_cpu_only + [
|
||||
"SharkEulerDiscrete",
|
||||
"SharkEulerAncestralDiscrete",
|
||||
]
|
||||
|
||||
predefined_models = [
|
||||
@@ -87,6 +90,10 @@ predefined_paint_models = [
|
||||
predefined_upscaler_models = [
|
||||
"stabilityai/stable-diffusion-x4-upscaler",
|
||||
]
|
||||
predefined_sdxl_models = [
|
||||
"stabilityai/sdxl-turbo",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
]
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
@@ -140,6 +147,12 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
|
||||
)
|
||||
]
|
||||
match custom_checkpoint_type:
|
||||
case "sdxl":
|
||||
files = [
|
||||
val
|
||||
for val in files
|
||||
if any(x in val for x in ["XL", "xl", "Xl"])
|
||||
]
|
||||
case "inpainting":
|
||||
files = [
|
||||
val
|
||||
@@ -247,6 +260,96 @@ def cancel_sd():
|
||||
pass
|
||||
|
||||
|
||||
def set_model_default_configs(model_ckpt_or_id, jsonconfig=None):
|
||||
import gradio as gr
|
||||
|
||||
config_modelname = default_config_exists(model_ckpt_or_id)
|
||||
if jsonconfig:
|
||||
return get_config_from_json(jsonconfig)
|
||||
elif config_modelname:
|
||||
return default_configs[config_modelname]
|
||||
# TODO: Use HF metadata to setup pipeline if available
|
||||
# elif is_valid_hf_id(model_ckpt_or_id):
|
||||
# return get_HF_default_configs(model_ckpt_or_id)
|
||||
else:
|
||||
# We don't have default metadata to setup a good config. Do not change configs.
|
||||
return [
|
||||
gr.Textbox(label="Prompt", interactive=True, visible=True),
|
||||
gr.Textbox(label="Negative Prompt", interactive=True),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.Checkbox(
|
||||
label="Auto-Generate",
|
||||
visible=False,
|
||||
interactive=False,
|
||||
value=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_config_from_json(model_ckpt_or_id, jsonconfig):
|
||||
# TODO: make this work properly. It is currently not user-exposed.
|
||||
cfgdata = json.load(jsonconfig)
|
||||
return [
|
||||
cfgdata["prompt_box_behavior"],
|
||||
cfgdata["neg_prompt_box_behavior"],
|
||||
cfgdata["steps"],
|
||||
cfgdata["scheduler"],
|
||||
cfgdata["guidance_scale"],
|
||||
cfgdata["width"],
|
||||
cfgdata["height"],
|
||||
cfgdata["custom_vae"],
|
||||
]
|
||||
|
||||
|
||||
def default_config_exists(model_ckpt_or_id):
|
||||
if model_ckpt_or_id in default_configs.keys():
|
||||
return model_ckpt_or_id
|
||||
elif "turbo" in model_ckpt_or_id.lower():
|
||||
return "stabilityai/sdxl-turbo"
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
default_configs = {
|
||||
"stabilityai/sdxl-turbo": [
|
||||
gr.Textbox(label="", interactive=False, value=None, visible=False),
|
||||
gr.Textbox(
|
||||
label="Prompt",
|
||||
value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette",
|
||||
),
|
||||
gr.Slider(0, 10, value=2),
|
||||
"EulerAncestralDiscrete",
|
||||
gr.Slider(0, value=0),
|
||||
512,
|
||||
512,
|
||||
"madebyollin/sdxl-vae-fp16-fix",
|
||||
gr.Checkbox(
|
||||
label="Auto-Generate", visible=False, interactive=True, value=False
|
||||
),
|
||||
],
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": [
|
||||
gr.Textbox(label="Prompt", interactive=True, visible=True),
|
||||
gr.Textbox(label="Negative Prompt", interactive=True),
|
||||
40,
|
||||
"EulerDiscrete",
|
||||
7.5,
|
||||
gr.Slider(value=768, interactive=True),
|
||||
gr.Slider(value=768, interactive=True),
|
||||
"madebyollin/sdxl-vae-fp16-fix",
|
||||
gr.Checkbox(
|
||||
label="Auto-Generate",
|
||||
visible=False,
|
||||
interactive=False,
|
||||
value=False,
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
nodicon_loc = resource_path("logos/nod-icon.png")
|
||||
available_devices = get_available_devices()
|
||||
|
||||
@@ -26,7 +26,7 @@ diffusers
|
||||
accelerate
|
||||
scipy
|
||||
ftfy
|
||||
gradio==3.44.3
|
||||
gradio==4.7.1
|
||||
altair
|
||||
omegaconf
|
||||
# 0.3.2 doesn't have binaries for arm64
|
||||
|
||||
@@ -83,7 +83,7 @@ def clean_device_info(raw_device):
|
||||
device_id = int(device_id)
|
||||
|
||||
if device not in ["rocm", "vulkan"]:
|
||||
device_id = ""
|
||||
device_id = None
|
||||
if device in ["rocm", "vulkan"] and device_id == None:
|
||||
device_id = 0
|
||||
return device, device_id
|
||||
@@ -355,11 +355,15 @@ def get_iree_module(
|
||||
device = iree_device_map(device)
|
||||
print("registering device id: ", device_idx)
|
||||
haldriver = ireert.get_driver(device)
|
||||
hal_device_id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
haldevice = haldriver.create_device(
|
||||
haldriver.query_available_devices()[device_idx]["device_id"],
|
||||
hal_device_id,
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
config.id = hal_device_id
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
vm_module = ireert.VmModule.from_buffer(
|
||||
@@ -398,15 +402,16 @@ def load_vmfb_using_mmap(
|
||||
haldriver = ireert.get_driver(device)
|
||||
dl.log(f"ireert.get_driver()")
|
||||
|
||||
hal_device_id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
haldevice = haldriver.create_device(
|
||||
haldriver.query_available_devices()[device_idx]["device_id"],
|
||||
hal_device_id,
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
dl.log(f"ireert.create_device()")
|
||||
config = ireert.Config(device=haldevice)
|
||||
config.id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
config.id = hal_device_id
|
||||
dl.log(f"ireert.Config()")
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
|
||||
@@ -183,6 +183,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
res_vulkan_flag = []
|
||||
res_vulkan_flag += [
|
||||
"--iree-stream-resource-max-allocation-size=3221225472"
|
||||
]
|
||||
vulkan_triple_flag = None
|
||||
for arg in extra_args:
|
||||
if "-iree-vulkan-target-triple=" in arg:
|
||||
@@ -204,7 +207,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
@functools.cache
|
||||
def get_iree_vulkan_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
|
||||
f"--vulkan_validation_layers={'true' if shark_args.vulkan_debug_utils else 'false'}",
|
||||
f"--vulkan_debug_verbosity={'4' if shark_args.vulkan_debug_utils else '0'}"
|
||||
f"--vulkan-robust-buffer-access={'true' if shark_args.vulkan_debug_utils else 'false'}",
|
||||
]
|
||||
return vulkan_runtime_flags
|
||||
|
||||
|
||||
62
shark/tests/test_txt2img_ui.py
Normal file
62
shark/tests/test_txt2img_ui.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import unittest
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
export_settings,
|
||||
load_settings,
|
||||
all_gradio_labels,
|
||||
)
|
||||
|
||||
|
||||
class TestExportSettings(unittest.TestCase):
|
||||
@patch("builtins.open", new_callable=mock_open)
|
||||
@patch("json.dump")
|
||||
def test_export_settings(self, mock_json_dump, mock_file):
|
||||
test_values = ["value1", "value2", "value3"]
|
||||
expected_output = {
|
||||
"txt2img": {
|
||||
label: value
|
||||
for label, value in zip(all_gradio_labels, test_values)
|
||||
}
|
||||
}
|
||||
|
||||
export_settings(*test_values)
|
||||
mock_file.assert_called_once_with("./ui/settings.json", "w")
|
||||
mock_json_dump.assert_called_once_with(
|
||||
expected_output, mock_file(), indent=4
|
||||
)
|
||||
|
||||
@patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
|
||||
@patch(
|
||||
"builtins.open",
|
||||
new_callable=mock_open,
|
||||
read_data='{"txt2img": {"some_setting": "some_value"}}',
|
||||
)
|
||||
def test_load_settings_file_exists(self, mock_file, mock_json_load):
|
||||
mock_json_load.return_value = {
|
||||
"txt2img": {
|
||||
"txt2img_custom_model": "custom_model_value",
|
||||
"custom_vae": "custom_vae_value",
|
||||
}
|
||||
}
|
||||
|
||||
settings = load_settings()
|
||||
self.assertEqual(settings[0], "custom_model_value")
|
||||
self.assertEqual(settings[1], "custom_vae_value")
|
||||
|
||||
@patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
|
||||
@patch("builtins.open", side_effect=FileNotFoundError)
|
||||
def test_load_settings_file_not_found(self, mock_file, mock_json_load):
|
||||
settings = load_settings()
|
||||
|
||||
default_lora_weights = "None"
|
||||
self.assertEqual(settings[4], default_lora_weights)
|
||||
|
||||
@patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="{}")
|
||||
def test_load_settings_key_error(self, mock_file, mock_json_load):
|
||||
mock_json_load.return_value = {}
|
||||
|
||||
settings = load_settings()
|
||||
default_lora_weights = "None"
|
||||
self.assertEqual(settings[4], default_lora_weights)
|
||||
Reference in New Issue
Block a user