Complete SD pipeline.

This commit is contained in:
Ean Garvey
2023-12-17 23:51:34 -06:00
parent b0151a77de
commit 58130432ab
8 changed files with 435 additions and 213 deletions

View File

@@ -4,8 +4,10 @@ import time
import os
import json
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
from random import randint
from turbine_models.custom_models.sd_inference import clip, unet, vae
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.web.utils.state import status_label
@@ -16,6 +18,8 @@ from apps.shark_studio.modules.prompt_encoding import get_weighted_text_embeddin
from apps.shark_studio.modules.img_processing import (
resize_stencil,
save_output_img,
resamplers,
resampler_list,
)
from apps.shark_studio.modules.ckpt_processing import (
@@ -30,8 +34,7 @@ sd_model_map = {
"clip": {
"initializer": clip.export_clip_model,
"external_weight_file": None,
"ireec_flags": ["--iree-flow-collapse-reduction-dims",
],
"ireec_flags": ["--iree-flow-collapse-reduction-dims"],
},
"vae_encode": {
"initializer": vae.export_vae_model,
@@ -48,6 +51,10 @@ sd_model_map = {
"vae_decode": {
"initializer": vae.export_vae_model,
"external_weight_file": None,
"ireec_flags": ["--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
],
},
}
@@ -78,15 +85,15 @@ class StableDiffusion(SharkPipelineBase):
custom_vae: str = None,
num_loras: int = 0,
import_ir: bool = True,
is_img2img: bool = False,
is_controlled: bool = False,
):
self.model_max_length = 77
self.batch_size = batch_size
self.precision = precision
self.is_img2img = is_img2img
self.dtype = torch.float16 if precision == "fp16" else torch.float32
self.height = height
self.width = width
self.scheduler_obj = {}
self.precision = precision
static_kwargs = {
"pipe": {},
"clip": {"hf_model_name": base_model_id},
@@ -98,6 +105,8 @@ class StableDiffusion(SharkPipelineBase):
#"num_loras": num_loras,
"height": height,
"width": width,
"precision": precision,
"max_length": 77 * 8,
},
"vae_encode": {
"hf_model_name": custom_vae if custom_vae else base_model_id,
@@ -105,13 +114,15 @@ class StableDiffusion(SharkPipelineBase):
"batch_size": batch_size,
"height": height,
"width": width,
"precision": precision,
},
"vae_decode": {
"hf_model_name": custom_vae,
"hf_model_name": custom_vae if custom_vae else base_model_id,
"vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None),
"batch_size": batch_size,
"height": height,
"width": width,
"precision": precision,
},
}
super().__init__(
@@ -135,26 +146,26 @@ class StableDiffusion(SharkPipelineBase):
gc.collect()
def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings):
def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings, is_img2img):
print(
f"\n[LOG] Preparing pipeline with scheduler {scheduler}"
f"\n[LOG] Custom embeddings currently unsupported."
)
self.is_img2img = is_img2img
schedulers = get_schedulers(self.base_model_id)
self.weights_path = get_checkpoints_path(self.pipe_id)
self.weights_path = get_checkpoints_path(self.safe_name(self.pipe_id))
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
# accepting a list of schedulers in batched cases.
for i in scheduler:
self.scheduler_obj[i] = schedulers[i]
print(f"[LOG] Loaded scheduler: {i}")
self.scheduler = schedulers[scheduler]
print(f"[LOG] Loaded scheduler: {scheduler}")
for model in adapters:
self.model_map[model] = adapters[model]
if os.path.isfile(custom_weights):
for i in self.model_map:
self.model_map[i]["external_weights_file"] = None
elif custom_weights != "":
print(f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?")
if custom_weights:
if os.path.isfile(custom_weights):
for i in self.model_map:
self.model_map[i]["external_weights_file"] = None
elif custom_weights:
print(f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?")
self.static_kwargs["pipe"] = {
# "external_weight_path": self.weights_path,
# "external_weights": "safetensors",
@@ -162,6 +173,92 @@ class StableDiffusion(SharkPipelineBase):
self.get_compiled_map(pipe_id=self.pipe_id)
print("\n[LOG] Pipeline successfully prepared for runtime.")
return
def generate_images(
    self,
    prompt,
    negative_prompt,
    image,
    steps,
    strength,
    guidance_scale,
    seed,
    ondemand,
    repeatable_seeds,
    use_base_vae,
    resample_type,
    control_mode,
    hints,
):
    """Run prompt encoding, latent preparation, denoising and VAE decode.

    Args:
        prompt: Prompt string or list of prompt strings.
        negative_prompt: Negative prompt string or list of strings.
        image: Init image input; only processed when ``self.is_img2img`` is set.
        steps: Number of inference steps handed to ``prepare_latents``.
        strength: Strength value handed to ``prepare_latents``.
        guidance_scale: Guidance scale handed to ``produce_img_latents``.
        seed: Seed for ``torch.manual_seed``; values outside the uint32 range
            are replaced with a random one.
        ondemand: Stored on ``self.ondemand``; when truthy the VAE decoder is
            unloaded after decoding.
        repeatable_seeds: Currently unused by this method.
        use_base_vae: Forwarded to ``decode_latents``.
        resample_type: Resampler name used when processing the init image.
        control_mode: Currently unused by this method.
        hints: Currently unused by this method.

    Returns:
        A list with the images returned by ``decode_latents``.
    """

    def _fit_to_batch(value):
        # None is left untouched rather than wrapped into a list of None.
        if value is None:
            return None
        if not isinstance(value, list):
            return [value] * self.batch_size
        if not value or len(value) >= self.batch_size:
            return value[: self.batch_size]
        # Repeat short lists and trim so exactly batch_size entries remain.
        repeats = -(-self.batch_size // len(value))
        return (value * repeats)[: self.batch_size]

    # TODO: batch the remaining per-request args (steps, strength,
    # guidance_scale, seed, resample_type, control_mode, hints).
    self.ondemand = ondemand
    if self.is_img2img:
        image, _ = self.process_sd_init_image(image, resample_type)
    else:
        image = None

    print("\n[LOG] Generating images...")
    # The previous loop only rebound its loop variable, so the prompts were
    # never actually normalized; assign the normalized values back here.
    prompt = _fit_to_batch(prompt)
    negative_prompt = _fit_to_batch(negative_prompt)

    text_embeddings = self.encode_prompts_weight(
        prompt,
        negative_prompt,
    )

    uint32_info = np.iinfo(np.uint32)
    uint32_min, uint32_max = uint32_info.min, uint32_info.max
    if seed < uint32_min or seed >= uint32_max:
        seed = randint(uint32_min, uint32_max)

    generator = torch.manual_seed(seed)

    init_latents, final_timesteps = self.prepare_latents(
        generator=generator,
        num_inference_steps=steps,
        image=image,
        strength=strength,
    )

    latents = self.produce_img_latents(
        latents=init_latents,
        text_embeddings=text_embeddings,
        guidance_scale=guidance_scale,
        total_timesteps=final_timesteps,
        cpu_scheduling=True,  # until we have schedulers through Turbine
    )

    # Img latents -> PIL images
    all_imgs = []
    self.load_submodels(["vae_decode"])
    for start in tqdm(range(0, latents.shape[0], self.batch_size)):
        imgs = self.decode_latents(
            latents=latents[start : start + self.batch_size],
            use_base_vae=use_base_vae,
            cpu_scheduling=True,
        )
        all_imgs.extend(imgs)
    if self.ondemand:
        self.unload_submodels(["vae_decode"])

    return all_imgs
def encode_prompts_weight(
@@ -191,84 +288,220 @@ class StableDiffusion(SharkPipelineBase):
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
pad = (0, 0) * (len(text_embeddings.shape) - 2)
pad = pad + (0, 512 - text_embeddings.shape[1])
pad = pad + (0, 77 * 8 - text_embeddings.shape[1])
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
self.unload_submodels(["clip"])
gc.collect()
print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}")
return text_embeddings.numpy().astype(np.float16)
def generate_images(
def prepare_latents(
self,
prompt,
negative_prompt,
steps,
generator,
num_inference_steps,
image,
strength,
guidance_scale,
seed,
ondemand,
repeatable_seeds,
resample_type,
control_mode,
hints,
):
print("\n[LOG] Generating images...")
batched_args=[
prompt,
negative_prompt,
steps,
strength,
guidance_scale,
seed,
resample_type,
control_mode,
hints,
]
for arg in batched_args:
if not isinstance(arg, list):
arg = [arg] * self.batch_size
if len(arg) < self.batch_size:
arg = arg * self.batch_size
else:
arg = [arg[i] for i in range(self.batch_size)]
noise = torch.randn(
(
self.batch_size,
4,
self.height // 8,
self.width // 8,
),
generator=generator,
dtype=self.dtype,
).to("cpu")
self.scheduler.set_timesteps(num_inference_steps)
if self.is_img2img:
init_timestep = min(
int(num_inference_steps * strength), num_inference_steps
)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]
latents = self.encode_image(image)
latents = self.scheduler.add_noise(
latents, noise, timesteps[0].repeat(1)
)
return latents, [timesteps]
else:
self.scheduler.is_scale_input_called = True
latents = noise * self.scheduler.init_noise_sigma
return latents, self.scheduler.timesteps
text_embeddings = self.encode_prompts_weight(
prompt,
negative_prompt,
)
print(text_embeddings)
test_img = [
Image.open(
get_resource_path("../../tests/jupiter.png"), mode="r"
).convert("RGB")
] * self.batch_size
return test_img
def encode_image(self, input_image):
    """Run the "vae_encode" submodel on ``input_image`` and return its output.

    The submodel is loaded first, the call is timed and logged, and the
    submodel is unloaded again when ``self.ondemand`` is truthy.
    """
    submodel = "vae_encode"
    self.load_submodels([submodel])

    started_at = time.time()
    latents = self.run(submodel, input_image)
    elapsed_ms = (time.time() - started_at) * 1000

    if self.ondemand:
        self.unload_submodels([submodel])
    print(f"\n[LOG] VAE Encode Inference time (ms): {elapsed_ms:.3f}")
    return latents
def produce_img_latents(
    self,
    latents,
    text_embeddings,
    guidance_scale,
    total_timesteps,
    cpu_scheduling,
    mask=None,
    masked_image_latents=None,
    return_all_latents=False,
):
    """Run the denoising loop over ``total_timesteps`` with the "unet" submodel.

    Args:
        latents: Starting latents (a torch tensor).
        text_embeddings: Numpy array of text embeddings.
        guidance_scale: Scalar guidance scale passed to the unet.
        total_timesteps: Sequence of timesteps to iterate over.
        cpu_scheduling: When truthy, the scheduler step runs through
            ``self.scheduler.step``; otherwise through the "scheduler_step"
            submodel.
        mask: Optional mask concatenated to the unet input along dim 1.
        masked_image_latents: Optional latents concatenated alongside ``mask``.
        return_all_latents: When truthy, return every intermediate latent
            (including the initial one) concatenated along dim 0.

    Returns:
        The final latents, or the concatenated latent history.
    """
    step_time_sum = 0
    steps_run = 0
    latent_history = [latents]

    # Cast every unet input to the pipeline dtype so fp32 pipelines are not
    # fed half-precision embeddings / guidance values.
    text_embeddings = torch.from_numpy(text_embeddings).to(self.dtype)
    text_embeddings_numpy = text_embeddings.detach().numpy()
    guidance_scale = (
        torch.tensor([guidance_scale]).to(self.dtype).detach().numpy()
    )

    self.load_submodels(["unet"])
    # Iterating the sequence directly lets tqdm pick up its length.
    for t in tqdm(total_timesteps):
        step_start_time = time.time()
        timestep = torch.tensor([t]).to(self.dtype).detach().numpy()
        latent_model_input = self.scheduler.scale_model_input(latents, t).to(
            self.dtype
        )
        if mask is not None and masked_image_latents is not None:
            latent_model_input = torch.cat(
                [
                    latent_model_input,
                    mask,
                    masked_image_latents,
                ],
                dim=1,
            ).to(self.dtype)
        if cpu_scheduling:
            latent_model_input = latent_model_input.detach().numpy()

        noise_pred = self.run(
            "unet",
            [
                latent_model_input,
                timestep,
                text_embeddings_numpy,
                guidance_scale,
            ],
        )

        if cpu_scheduling:
            noise_pred = torch.from_numpy(noise_pred.to_host())
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
        else:
            latents = self.run("scheduler_step", (noise_pred, t, latents))

        latent_history.append(latents)
        step_time_sum += (time.time() - step_start_time) * 1000
        steps_run += 1
        # TODO: honor a cancel request here once pipeline status is wired up.

    if self.ondemand:
        self.unload_submodels(["unet"])
        gc.collect()

    # Guard against an empty timestep sequence instead of dividing by zero.
    avg_step_time = step_time_sum / steps_run if steps_run else 0.0
    print(f"\n[LOG] Average step time: {avg_step_time}ms/it")

    if not return_all_latents:
        return latents
    all_latents = torch.cat(latent_history, dim=0)
    return all_latents
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
    """Decode latents through the "vae_decode" submodel into images.

    Args:
        latents: Torch tensor of latents to decode.
        use_base_vae: When truthy, latents are divided by 0.18215 before
            decoding and the decoder output is multiplied by 255 and rounded
            afterwards.
        cpu_scheduling: When truthy, the latents are converted to a numpy
            array before being handed to the submodel.

    Returns:
        A list of images built with ``Image.fromarray``, one per batch entry.
    """
    if use_base_vae:
        # presumably the SD VAE latent scaling factor — confirm
        latents = 1 / 0.18215 * latents

    vae_input = latents.to(self.dtype)
    if cpu_scheduling:
        # Convert the *cast* tensor; converting `latents` here would throw
        # away the dtype conversion done on the previous line.
        vae_input = vae_input.detach().numpy()

    vae_start = time.time()
    images = self.run("vae_decode", vae_input).to_host()
    vae_inf_time = (time.time() - vae_start) * 1000
    print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}")

    if use_base_vae:
        images = torch.from_numpy(images)
        images = (images.detach().cpu() * 255.0).numpy()
        images = images.round()
    # NOTE(review): the non-base path is cast to uint8 without rescaling,
    # presumably because that decoder already emits 0-255 values — confirm.

    images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
    pil_images = [Image.fromarray(image) for image in images.numpy()]
    return pil_images
def process_sd_init_image(self, sd_init_image, resample_type):
    """Turn an init-image input into a normalized tensor.

    Args:
        sd_init_image: A list of inputs, a file path, an ``Image.Image``, a
            mapping with an ``"image"`` entry, or an empty/None value.
        resample_type: Key into ``resamplers``; names not present in
            ``resampler_list`` fall back to Lanczos.

    Returns:
        ``(image, is_img2img)``. For a list input, ``image`` is a list of the
        per-item results and ``is_img2img`` is always True. For a single
        input, ``image`` is a ``(1, 3, height, width)`` tensor in
        ``self.dtype`` scaled to [-1, 1], or None (with ``is_img2img`` False)
        when there is no usable image.
    """
    if isinstance(sd_init_image, list):
        images = []
        for img in sd_init_image:
            img, _ = self.process_sd_init_image(img, resample_type)
            images.append(img)
        return images, True

    if isinstance(sd_init_image, str):
        if not os.path.isfile(sd_init_image):
            return None, False
        with Image.open(sd_init_image, mode="r") as opened:
            loaded = opened.convert("RGB")
        # Return the recursive result directly: it is already a tensor, and
        # truth-testing a multi-element tensor raises a RuntimeError.
        return self.process_sd_init_image(loaded, resample_type)

    if not sd_init_image:
        return None, False

    if isinstance(sd_init_image, Image.Image):
        image = sd_init_image.convert("RGB")
    else:
        image = sd_init_image["image"].convert("RGB")

    resample = (
        resamplers[resample_type]
        if resample_type in resampler_list
        # Fallback to Lanczos
        else Image.Resampling.LANCZOS
    )
    image = image.resize((self.width, self.height), resample=resample)

    # HWC uint8 -> 1xCxHxW in self.dtype, rescaled from [0, 255] to [-1, 1].
    image_arr = np.stack([np.array(image)], axis=0)
    image_arr = image_arr / 255.0
    image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
    image_arr = 2 * (image_arr - 0.5)
    return image_arr, True
def shark_sd_fn_dict_input(
sd_kwargs: dict,
):
print("[LOG] Submitting Request...")
input_imgs = []
img_paths = sd_kwargs["sd_init_image"]
for img_path in img_paths:
if img_path:
if os.path.isfile(img_path):
input_imgs.append(
Image.open(img_path, mode="r").convert("RGB")
)
sd_kwargs["sd_init_image"] = input_imgs
# result = shark_sd_fn(**sd_kwargs)
# for i in range(sd_kwargs["batch_count"]):
# yield from result
# return result
for key in sd_kwargs:
if sd_kwargs[key] in ["None", "", None, []]:
sd_kwargs[key] = None
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])
for i in range(1):
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
yield generated_imgs
@@ -278,7 +511,7 @@ def shark_sd_fn_dict_input(
def shark_sd_fn(
prompt,
negative_prompt,
sd_init_image,
sd_init_image: list,
height: int,
width: int,
steps: int,
@@ -291,6 +524,7 @@ def shark_sd_fn(
base_model_id: str,
custom_weights: str,
custom_vae: str,
use_base_vae: bool,
precision: str,
device: str,
ondemand: bool,
@@ -300,20 +534,9 @@ def shark_sd_fn(
embeddings: dict,
):
sd_kwargs = locals()
if isinstance(sd_init_image, Image.Image):
image = sd_init_image.convert("RGB")
elif sd_init_image:
image = sd_init_image["image"].convert("RGB")
else:
image = None
is_img2img = False
if image:
(
image,
_,
_,
) = resize_stencil(image, width, height)
is_img2img = True
is_img2img = True if sd_init_image[0] is not None else False
print("\n[LOG] Performing Stable Diffusion Pipeline setup...")
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
@@ -358,7 +581,6 @@ def shark_sd_fn(
"custom_vae": custom_vae,
"num_loras": num_loras,
"import_ir": cmd_opts.import_mlir,
"is_img2img": is_img2img,
"is_controlled": is_controlled,
}
submit_prep_kwargs = {
@@ -366,16 +588,19 @@ def shark_sd_fn(
"custom_weights": custom_weights,
"adapters": adapters,
"embeddings": embeddings,
"is_img2img": is_img2img,
}
submit_run_kwargs = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"image": sd_init_image,
"steps": steps,
"strength": strength,
"guidance_scale": guidance_scale,
"seed": seed,
"ondemand": ondemand,
"repeatable_seeds": repeatable_seeds,
"use_base_vae": use_base_vae,
"resample_type": resample_type,
"control_mode": control_mode,
"hints": hints,
@@ -410,13 +635,9 @@ def shark_sd_fn(
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
# break
# else:
try:
this_seed = seed[current_batch]
except:
this_seed = seed[0]
save_output_img(
out_imgs[0],
this_seed,
seed,
sd_kwargs,
)
generated_imgs.extend(out_imgs)
@@ -438,6 +659,7 @@ def view_json_file(file_path):
return content
if __name__ == "__main__":
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj

View File

@@ -75,9 +75,9 @@ def save_output_img(output_img, img_seed, extra_info=None):
"parameters",
f"{extra_info['prompt'][0]}"
f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
f"\nSteps: {extra_info['steps'][0]},"
f"Sampler: {extra_info['scheduler'][0]}, "
f"CFG scale: {extra_info['guidance_scale'][0]}, "
f"\nSteps: {extra_info['steps']},"
f"Sampler: {extra_info['scheduler']}, "
f"CFG scale: {extra_info['guidance_scale']}, "
f"Seed: {img_seed},"
f"Size: {png_size_text}, "
f"Model: {img_model}, "

View File

@@ -1,5 +1,10 @@
from msvcrt import kbhit
from shark.iree_utils.compile_utils import get_iree_compiled_module, load_vmfb_using_mmap
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
clean_device_info,
get_iree_target_triple,
)
from apps.shark_studio.web.utils.file_utils import (
get_checkpoints_path,
get_resource_path,
@@ -32,8 +37,8 @@ class SharkPipelineBase:
self.model_map = model_map
self.static_kwargs = static_kwargs
self.base_model_id = base_model_id
self.device_name = device
self.device = device.split("=>")[-1].strip(" ")
self.triple = get_iree_target_triple(device)
self.device, self.device_id = clean_device_info(device)
self.import_mlir = import_mlir
self.iree_module_dict = {}
self.tempfiles = {}
@@ -46,11 +51,11 @@ class SharkPipelineBase:
# initialization. As soon as you have a pipeline ID unique to your static torch IR parameters,
# and your model map is populated with any IR - unique model IDs and their static params,
# call this method to get the artifacts associated with your map.
self.pipe_id = pipe_id
self.pipe_id = self.safe_name(pipe_id)
self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(".."), self.pipe_id))
self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True)
print("\n[LOG] Checking for pre-compiled artifacts.")
if submodel == "None":
print("\n[LOG] Gathering any pre-compiled artifacts....")
for key in self.model_map:
self.get_compiled_map(pipe_id, submodel=key)
else:
@@ -58,10 +63,12 @@ class SharkPipelineBase:
ireec_flags = []
if submodel in self.iree_module_dict:
if "vmfb" in self.iree_module_dict[submodel]:
print(f"[LOG] Found executable for {submodel} at {self.iree_module_dict[submodel]['vmfb']}...")
print(f"\n[LOG] Executable for {submodel} already loaded...")
return
elif "vmfb_path" in self.model_map[submodel]:
return
elif submodel not in self.tempfiles:
print(f"[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
print(f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
if submodel in self.static_kwargs:
init_kwargs = self.static_kwargs[submodel]
for key in self.static_kwargs["pipe"]:
@@ -90,16 +97,6 @@ class SharkPipelineBase:
return
def hijack_weights(self, weights_path, submodel="None"):
if submodel == "None":
for i in self.model_map:
self.hijack_weights(weights_path, i)
else:
if submodel in self.iree_module_dict:
self.model_map[submodel]["external_weights_file"] = weights_path
return
def get_precompiled(self, pipe_id, submodel="None"):
if submodel == "None":
for model in self.model_map:
@@ -112,33 +109,10 @@ class SharkPipelineBase:
break
for file in vmfbs:
if submodel in file:
print(f"Found existing .vmfb at {file}")
self.iree_module_dict[submodel] = {}
(
self.iree_module_dict[submodel]["vmfb"],
self.iree_module_dict[submodel]["config"],
self.iree_module_dict[submodel]["temp_file_to_unlink"],
) = load_vmfb_using_mmap(
os.path.join(vmfbs_path, file),
self.device,
device_idx=0,
rt_flags=[],
external_weight_file=self.model_map[submodel]['external_weight_file'],
)
self.model_map[submodel]["vmfb_path"] = os.path.join(vmfbs_path, file)
return
def safe_dict(self, kwargs: dict):
flat_args = {}
for i in kwargs:
if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
else:
flat_args[i] = kwargs[i]
return flat_args
def import_torch_ir(self, submodel, kwargs):
torch_ir = self.model_map[submodel]["initializer"](
**self.safe_dict(kwargs), compile_to="torch"
@@ -160,18 +134,53 @@ class SharkPipelineBase:
def load_submodels(self, submodels: list):
for submodel in submodels:
if submodel in self.iree_module_dict:
print(f"\n[LOG] {submodel} is ready for inference.")
if "vmfb_path" in self.model_map[submodel]:
print(
f"\n[LOG] Loading .vmfb for {submodel} from {self.iree_module_dict[submodel]['vmfb']}"
f"\n[LOG] Loading .vmfb for {submodel} from {self.model_map[submodel]['vmfb_path']}"
)
self.iree_module_dict[submodel] = {}
(
self.iree_module_dict[submodel]["vmfb"],
self.iree_module_dict[submodel]["config"],
self.iree_module_dict[submodel]["temp_file_to_unlink"],
) = load_vmfb_using_mmap(
self.model_map[submodel]["vmfb_path"],
self.device,
device_idx=0,
rt_flags=[],
external_weight_file=self.model_map[submodel]['external_weight_file'],
)
else:
self.get_compiled_map(self.pipe_id, submodel)
return
def unload_submodels(self, submodels: list):
    """Drop the loaded module entry for each named submodel, if present.

    A garbage collection pass is triggered after every entry that is removed
    from ``self.iree_module_dict``; unknown names are ignored.
    """
    for name in submodels:
        if name not in self.iree_module_dict:
            continue
        self.iree_module_dict.pop(name)
        gc.collect()
    return
def run(self, submodel, inputs):
inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, inputs)]
if not isinstance(inputs, list):
inputs = [inputs]
inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, input) for input in inputs]
return self.iree_module_dict[submodel]['vmfb']['main'](*inp)
def safe_name(name):
return name.replace("/", "_").replace("-", "_")
def safe_name(self, name):
return name.replace("/", "_").replace("-", "_").replace("\\", "_")
def safe_dict(self, kwargs: dict):
    """Return a copy of ``kwargs`` with nested dicts flattened to value lists.

    A dict value is replaced by a list of its values unless it contains the
    key ``"pass_dict"``, in which case it is kept as-is. Non-dict values are
    copied through unchanged.
    """

    def _flatten(value):
        if isinstance(value, dict) and "pass_dict" not in value:
            return [value[inner] for inner in value]
        return value

    return {name: _flatten(kwargs[name]) for name in kwargs}

View File

@@ -3,6 +3,7 @@ from typing import List, Optional, Union
from iree import runtime as ireert
import re
import torch
import numpy as np
re_attention = re.compile(
r"""
@@ -161,7 +162,7 @@ def pad_tokens_and_weights(
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = 8
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = (
max_length
if no_boseos_middle
@@ -194,13 +195,16 @@ def pad_tokens_and_weights(
return tokens, weights
def get_unweighted_text_embeddings(
pipe,
text_input: torch.Tensor,
text_input,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
@@ -214,7 +218,7 @@ def get_unweighted_text_embeddings(
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.run("clip", text_input_chunk)[0]
text_embedding = pipe.run("clip", text_input_chunk)[0].to_host()
if no_boseos_middle:
if i == 0:
@@ -231,50 +235,14 @@ def get_unweighted_text_embeddings(
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
text_embeddings = torch.from_numpy(text_embeddings_np)
else:
text_embeddings = pipe.run("clip", text_input)[0]
# text_embeddings = torch.from_numpy(text_embeddings)[None, :]
return torch.from_numpy(text_embeddings.to_host())
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = 8
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
# text_embedding = pipe.text_encoder(text_input_chunk)[0]
print(text_input_chunk)
breakpoint()
text_embedding = pipe.run("clip", text_input_chunk)
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
text_embeddings = torch.from_numpy(text_embeddings.to_host())
return text_embeddings
# This function deals with NoneType values occurring in tokens after padding
# It switches out None with 49407 as truncating None values causes matrix dimension errors,
def filter_nonetype_tokens(tokens: List[List]):
@@ -286,7 +254,7 @@ def get_weighted_text_embeddings(
prompt: List[str],
uncond_prompt: List[str] = None,
max_embeddings_multiples: Optional[int] = 8,
no_boseos_middle: Optional[bool] = False,
no_boseos_middle: Optional[bool] = True,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
):
@@ -325,12 +293,12 @@ def get_weighted_text_embeddings(
max_length = max(
max_length, max([len(token) for token in uncond_tokens])
)
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights

View File

@@ -18,7 +18,7 @@ from diffusers import (
def get_schedulers(model_id):
#TODO: switch over to turbine and run all on GPU
print(f"[LOG] Initializing schedulers from model id: {model_id}")
print(f"\n[LOG] Initializing schedulers from model id: {model_id}")
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,

View File

@@ -1,23 +1,24 @@
{
"prompt": [ "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" ],
"negative_prompt": [ "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" ],
"sd_init_image": [ "None" ],
"sd_init_image": [ null ],
"height": 512,
"width": 512,
"steps": [ 50 ],
"strength": [ 0.8 ],
"guidance_scale": [ 7.5 ],
"seed": [ -1 ],
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": -1,
"batch_count": 1,
"batch_size": 1,
"scheduler": [ "EulerDiscrete" ],
"scheduler": "EulerDiscrete",
"base_model_id": "runwayml/stable-diffusion-v1-5",
"custom_weights": "",
"custom_vae": "",
"custom_weights": null,
"custom_vae": null,
"use_base_vae": false,
"precision": "fp16",
"device": "vulkan",
"ondemand": "False",
"repeatable_seeds": "False",
"ondemand": false,
"repeatable_seeds": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}

View File

@@ -41,6 +41,14 @@ from apps.shark_studio.web.ui.common_events import lora_changed
from apps.shark_studio.modules import logger
import apps.shark_studio.web.utils.globals as global_obj
# Hugging Face model IDs offered as presets in the base-model dropdown.
sd_default_models = [
    "CompVis/stable-diffusion-v1-4",
    "runwayml/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-2-1-base",
    "stabilityai/stable-diffusion-2-1",
    # The SDXL 1.0 base checkpoint is published under the "-base-" repo id.
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/sdxl-turbo",
]
def view_json_file(file_path):
content = ""
@@ -105,6 +113,7 @@ def pull_sd_configs(
base_model_id,
custom_weights,
custom_vae,
use_base_vae,
precision,
device,
ondemand,
@@ -120,11 +129,6 @@ def pull_sd_configs(
"prompt",
"negative_prompt",
"sd_init_image",
"steps",
"strength",
"guidance_scale",
"seed",
"scheduler",
]:
sd_cfg[arg] = [sd_args[arg]]
elif arg in ["controlnets", "embeddings"]:
@@ -144,8 +148,10 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_json[key] = new_sd_config[key]
else:
sd_json = new_sd_config
if os.path.isfile(sd_json["sd_init_image"][0]):
sd_image = Image.open(sd_json["sd_init_image"][0], mode="r")
for i in sd_json["sd_init_image"]:
if i is not None:
if os.path.isfile(i):
sd_image = [Image.open(i, mode="r")]
else:
sd_image = None
@@ -155,16 +161,17 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_image,
sd_json["height"],
sd_json["width"],
sd_json["steps"][0],
sd_json["strength"][0],
sd_json["steps"],
sd_json["strength"],
sd_json["guidance_scale"],
sd_json["seed"][0],
sd_json["seed"],
sd_json["batch_count"],
sd_json["batch_size"],
sd_json["scheduler"][0],
sd_json["scheduler"],
sd_json["base_model_id"],
sd_json["custom_weights"],
sd_json["custom_vae"],
sd_json["use_base_vae"],
sd_json["precision"],
sd_json["device"],
sd_json["ondemand"],
@@ -320,7 +327,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
info="Select or enter HF model ID",
elem_id="custom_model",
value="stabilityai/stable-diffusion-2-1-base",
choices=sd_model_map.keys(),
choices=sd_default_models,
) # base_model_id
custom_weights = gr.Dropdown(
label="Custom Weights",
@@ -328,7 +335,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
elem_id="custom_model",
value="None",
allow_custom_value=True,
choices=get_checkpoints(base_model_id),
choices=["None"] + get_checkpoints(base_model_id),
) #
with gr.Column(scale=2):
sd_vae_info = (
@@ -361,6 +368,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
],
visible=True,
)
use_base_vae = gr.Checkbox(
value=False,
label="Baked VAE",
interactive=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
@@ -639,7 +651,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Column(scale=3, min_width=600):
with gr.Group():
sd_gallery = gr.Gallery(
label="Generated images",
label="Generated images",
show_label=False,
elem_id="gallery",
columns=2,
@@ -719,6 +731,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
base_model_id,
custom_weights,
custom_vae,
use_base_vae,
precision,
device,
ondemand,
@@ -753,6 +766,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
base_model_id,
custom_weights,
custom_vae,
use_base_vae,
precision,
device,
ondemand,

View File

@@ -65,6 +65,14 @@ def get_iree_device_args(device, extra_args=[]):
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
return []
def get_iree_target_triple(device):
    """Return the target triple found in the IREE device args for ``device``.

    Scans the flags returned by ``get_iree_device_args(device)`` for the
    first one that mentions "triple" and carries a ``=value`` part, and
    returns that value as a string. Returns "" when no such flag exists.
    """
    for flag in get_iree_device_args(device):
        if "triple" not in flag:
            continue
        # Splitting on "-" never isolates "triple" in a `...-triple=<value>`
        # flag, so match on the substring and take what follows the "=".
        _, separator, value = flag.partition("=")
        if separator:
            return value
    return ""
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by Studio pipeline